Skip to content

[mlir][TOSA] restore unrealized casts when lowering rescale ops #141096

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 26, 2025

Conversation

AGindinson
Copy link
Contributor

@AGindinson AGindinson commented May 22, 2025

Along with the changes to rescale op attributes, commit 7208649 dropped the builtin casts between signed and signless types. However, explicitly unsigned types are still legal input and output values from the TOSA IR perspective.

The change adds back the casts when the unsigned<->signless semantics are explicit in the underlying tensor types. This prevents the conversion routine from trying to generate illegal arith casts that are constrained to signless types. Whether the arith casts themselves are signed or unsigned should still depend on the rescale's *_unsigned attribute values.

Along with the changes to rescale op attributes, commit 7208649 dropped the builtin casts
between signed and signless types. However, explicitly unsigned types are still legal input
and output values from the TOSA IR perspective, and TF-to-TOSA's `ConvertUint8ToInt8` pass
generates just that.

The change adds back the casts when the unsigned<->signless semantics are explicit in the
underlying tensor types. This prevents the conversion routine from trying to generate illegal
`arith` casts that are contrained to signless types. Whether the `arith` casts themselves are
signed or unsigned should still depend on the rescale's `*_unsigned` attribute values.

Signed-off-by: Artem Gindinson <gindinson@roofline.ai>
Copy link

Thank you for submitting a Pull Request (PR) to the LLVM Project!

This PR will be automatically labeled and the relevant teams will be notified.

If you wish to, you can add reviewers by using the "Reviewers" section on this page.

If this is not working for you, it is probably because you do not have write permissions for the repository. In which case you can instead tag reviewers by name in a comment by using @ followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide.

You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.

@AGindinson
Copy link
Contributor Author

@RoboTux, @GeorgeARM, could you please take a look?

@llvmbot
Copy link
Member

llvmbot commented May 22, 2025

@llvm/pr-subscribers-mlir-tosa

@llvm/pr-subscribers-mlir-linalg

Author: Artem Gindinson (AGindinson)

Changes

Along with the changes to rescale op attributes, commit 7208649 dropped the builtin casts between signed and signless types. However, explicitly unsigned types are still legal input and output values from the TOSA IR perspective, and TF-to-TOSA's ConvertUint8ToInt8 pass generates just that.

The change adds back the casts when the unsigned<->signless semantics are explicit in the underlying tensor types. This prevents the conversion routine from trying to generate illegal arith casts that are contrained to signless types. Whether the arith casts themselves are signed or unsigned should still depend on the rescale's *_unsigned attribute values.


Full diff: https://github.com/llvm/llvm-project/pull/141096.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+15)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+133-30)
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 0b69cd2814fb9..6d73f23e2aae1 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1492,6 +1492,15 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
                                                 : blockArgs[multiplierArg];
           Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];
 
+          if (valueTy.isUnsignedInteger()) {
+            value = nestedBuilder
+                        .create<UnrealizedConversionCastOp>(
+                            nestedLoc,
+                            nestedBuilder.getIntegerType(
+                                valueTy.getIntOrFloatBitWidth()),
+                            value)
+                        .getResult(0);
+          }
           if (valueTy.getIntOrFloatBitWidth() < 32) {
             if (op.getInputUnsigned()) {
               value = nestedBuilder.create<arith::ExtUIOp>(
@@ -1537,6 +1546,12 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
                 value);
           }
 
+          if (outIntType.isUnsignedInteger()) {
+            value = nestedBuilder
+                        .create<UnrealizedConversionCastOp>(nestedLoc,
+                                                            outIntType, value)
+                        .getResult(0);
+          }
           nestedBuilder.create<linalg::YieldOp>(loc, value);
         });
 
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 185f1973ecdc6..8fc3d47d9f2ae 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1152,16 +1152,50 @@ func.func @rescale_i8(%arg0 : tensor<2xi8>) -> () {
 // -----
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
 
-// CHECK-LABEL: @rescale_i8_unsigned_output
+// CHECK-LABEL: @rescale_i8_unsigned_output_explicit
 // CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
-func.func @rescale_i8_unsigned_output(%arg0 : tensor<2xi8>) -> () {
+func.func @rescale_i8_unsigned_output_explicit(%arg0 : tensor<2xi8>) -> () {
+  // CHECK: [[C0:%.+]] = arith.constant 19689
+  // CHECK: [[C1:%.+]] = arith.constant 15
+  // CHECK: [[INIT:%.+]] = tensor.empty()
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi8>) outs([[INIT]] : tensor<2xui8>)
+  // CHECK: ^bb0([[IN:%.+]]: i8, [[UNUSED:%.+]]: ui8):
+  // CHECK-DAG: [[C17:%.+]] = arith.constant 17
+  // CHECK-DAG: [[C234:%.+]] = arith.constant 234
+  // CHECK-DAG: [[IN32:%.+]] = arith.extsi [[IN]]
+  // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C17]]
+  // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {rounding_mode = "SINGLE_ROUND"}
+  // CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C234]]
+  // CHECK-DAG: [[CMIN:%.+]] = arith.constant 0
+  // CHECK-DAG: [[CMAX:%.+]] = arith.constant 255
+  // CHECK-DAG: [[LOWER:%.+]] = arith.maxsi [[CMIN]], [[SCALED_ZEROED]]
+  // CHECK: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
+  // CHECK: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
+  // CHECK: [[TRUNC_ITOU:%.+]] = builtin.unrealized_conversion_cast [[TRUNC]] : i8 to ui8
+  // CHECK: linalg.yield [[TRUNC_ITOU]]
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+  %shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<17> : tensor<1xi8>} : () -> tensor<1xi8>
+  %output_zp = "tosa.const"() {values = dense<-22> : tensor<1xi8>} : () -> tensor<1xi8>
+  %1 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xui8>
+
+  // CHECK: return
+  return
+}
+
+// -----
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
+
+// CHECK-LABEL: @rescale_i8_unsigned_output_implicit
+// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
+func.func @rescale_i8_unsigned_output_implicit(%arg0 : tensor<2xi8>) -> () {
   // CHECK: [[C0:%.+]] = arith.constant 19689
   // CHECK: [[C1:%.+]] = arith.constant 15
   // CHECK: [[INIT:%.+]] = tensor.empty()
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi8>) outs([[INIT]] : tensor<2xi8>)
   // CHECK: ^bb0([[IN:%.+]]: i8, [[UNUSED:%.+]]: i8):
-  // CHECK: [[C17:%.+]] = arith.constant 17
-  // CHECK: [[C234:%.+]] = arith.constant 234
+  // CHECK-DAG: [[C17:%.+]] = arith.constant 17
+  // CHECK-DAG: [[C234:%.+]] = arith.constant 234
   // CHECK-DAG: [[IN32:%.+]] = arith.extsi [[IN]]
   // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C17]]
   // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {rounding_mode = "SINGLE_ROUND"}
@@ -1169,8 +1203,9 @@ func.func @rescale_i8_unsigned_output(%arg0 : tensor<2xi8>) -> () {
   // CHECK-DAG: [[CMIN:%.+]] = arith.constant 0
   // CHECK-DAG: [[CMAX:%.+]] = arith.constant 255
   // CHECK-DAG: [[LOWER:%.+]] = arith.maxsi [[CMIN]], [[SCALED_ZEROED]]
-  // CHECK-DAG: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
-  // CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
+  // CHECK: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
+  // CHECK: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
+  // CHECK-NOT: builtin.unrealized_conversion_cast [[TRUNC]] : i8 to i8
   // CHECK: linalg.yield [[TRUNC]]
   %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
   %shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
@@ -1182,6 +1217,39 @@ func.func @rescale_i8_unsigned_output(%arg0 : tensor<2xi8>) -> () {
   return
 }
 
+// -----
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
+
+// CHECK-LABEL: @rescale_i48_unsigned_output_implicit
+// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
+func.func @rescale_i48_unsigned_output_implicit(%arg0 : tensor<2xi48>) -> () {
+  // CHECK: [[C19689:%.+]] = arith.constant 19689
+  // CHECK: [[C15:%.+]] = arith.constant 15
+  // CHECK: [[INIT:%.+]] = tensor.empty()
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi48>) outs([[INIT]] : tensor<2xi8>)
+  // CHECK: ^bb0([[IN:%.+]]: i48, [[UNUSED:%.+]]: i8):
+  // CHECK-NOT: builtin.unrealized_conversion_cast [[IN]] : i48 to i48
+  // CHECK-DAG: [[C0:%.+]] = arith.constant 0
+  // CHECK-DAG: [[C234:%.+]] = arith.constant 234
+  // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN]], [[C0]]
+  // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C19689]], [[C15]] {rounding_mode = "SINGLE_ROUND"}
+  // CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C234]]
+  // CHECK-DAG: [[CMIN:%.+]] = arith.constant 0
+  // CHECK-DAG: [[CMAX:%.+]] = arith.constant 255
+  // CHECK-DAG: [[LOWER:%.+]] = arith.maxsi [[CMIN]], [[SCALED_ZEROED]]
+  // CHECK-DAG: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
+  // CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
+  // CHECK: linalg.yield [[TRUNC]]
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+  %shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi48>} : () -> tensor<1xi48>
+  %output_zp = "tosa.const"() {values = dense<-22> : tensor<1xi8>} : () -> tensor<1xi8>
+  %1 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<2xi48>, tensor<1xi16>, tensor<1xi8>, tensor<1xi48>, tensor<1xi8>) -> tensor<2xi8>
+
+  // CHECK: return
+  return
+}
+
 // -----
 
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
@@ -1230,19 +1298,52 @@ func.func @rescale_dyn(%arg0 : tensor<1x?x?x32xi32>) -> () {
 }
 
 // -----
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
 
+// CHECK-LABEL: @rescale_i8_unsigned_input_explicit
+// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
+func.func @rescale_i8_unsigned_input_explicit(%arg0 : tensor<2xui8>) -> () {
+  // CHECK: [[C0:%.+]] = arith.constant 19689
+  // CHECK: [[C1:%.+]] = arith.constant 15
+  // CHECK: [[INIT:%.+]] = tensor.empty()
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xui8>) outs([[INIT]] : tensor<2xi8>)
+  // CHECK: ^bb0([[IN:%.+]]: ui8, [[UNUSED:%.+]]: i8):
+  // CHECK-DAG: [[C17:%.+]] = arith.constant 17
+  // CHECK-DAG: [[C22:%.+]] = arith.constant 22
+  // CHECK-DAG: [[IN_UTOI:%.+]] = builtin.unrealized_conversion_cast [[IN]] : ui8 to i8
+  // CHECK-DAG: [[IN32:%.+]] = arith.extui [[IN_UTOI]]
+  // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C17]]
+  // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {rounding_mode = "SINGLE_ROUND"}
+  // CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C22]]
+  // CHECK-DAG: [[CMIN:%.+]] = arith.constant -128
+  // CHECK-DAG: [[CMAX:%.+]] = arith.constant 127
+  // CHECK-DAG: [[LOWER:%.+]] = arith.maxsi [[CMIN]], [[SCALED_ZEROED]]
+  // CHECK: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
+  // CHECK: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
+  // CHECK: linalg.yield [[TRUNC]]
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+  %shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<17> : tensor<1xi8>} : () -> tensor<1xi8>
+  %output_zp = "tosa.const"() {values = dense<22> : tensor<1xi8>} : () -> tensor<1xi8>
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<2xui8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8>
+
+  return
+}
+
+// -----
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
 
-// CHECK-LABEL: @rescale_i8_unsigned_input
+// CHECK-LABEL: @rescale_i8_unsigned_input_implicit
 // CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
-func.func @rescale_i8_unsigned_input(%arg0 : tensor<2xi8>) -> () {
+func.func @rescale_i8_unsigned_input_implicit(%arg0 : tensor<2xi8>) -> () {
   // CHECK: [[C0:%.+]] = arith.constant 19689
   // CHECK: [[C1:%.+]] = arith.constant 15
   // CHECK: [[INIT:%.+]] = tensor.empty()
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi8>) outs([[INIT]] : tensor<2xi8>)
   // CHECK: ^bb0([[IN:%.+]]: i8, [[UNUSED:%.+]]: i8):
-  // CHECK: [[C128:%.+]] = arith.constant 128
-  // CHECK: [[C22:%.+]] = arith.constant 22
+  // CHECK-NOT: builtin.unrealized_conversion_cast [[IN]] : i8 to i8
+  // CHECK-DAG: [[C128:%.+]] = arith.constant 128
+  // CHECK-DAG: [[C22:%.+]] = arith.constant 22
   // CHECK-DAG: [[IN32:%.+]] = arith.extui [[IN]]
   // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C128]]
   // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {rounding_mode = "SINGLE_ROUND"}
@@ -1265,32 +1366,34 @@ func.func @rescale_i8_unsigned_input(%arg0 : tensor<2xi8>) -> () {
 // -----
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
 
-// CHECK-LABEL: @rescale_i48_unsigned_output
+// CHECK-LABEL: @rescale_i8_unsigned_input_output_explicit
 // CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
-func.func @rescale_i48_unsigned_output(%arg0 : tensor<2xi48>) -> () {
-  // CHECK: [[C19689:%.+]] = arith.constant 19689
-  // CHECK: [[C15:%.+]] = arith.constant 15
+func.func @rescale_i8_unsigned_input_output_explicit(%arg0 : tensor<2xui8>) -> () {
+  // CHECK: [[C0:%.+]] = arith.constant 19689
+  // CHECK: [[C1:%.+]] = arith.constant 15
   // CHECK: [[INIT:%.+]] = tensor.empty()
-  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi48>) outs([[INIT]] : tensor<2xi8>)
-  // CHECK: ^bb0([[IN:%.+]]: i48, [[UNUSED:%.+]]: i8):
-  // CHECK: [[C0:%.+]] = arith.constant 0
-  // CHECK: [[C234:%.+]] = arith.constant 234
-  // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN]], [[C0]]
-  // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C19689]], [[C15]] {rounding_mode = "SINGLE_ROUND"}
-  // CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C234]]
-  // CHECK-DAG: [[CMIN:%.+]] = arith.constant 0
-  // CHECK-DAG: [[CMAX:%.+]] = arith.constant 255
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xui8>) outs([[INIT]] : tensor<2xui8>)
+  // CHECK: ^bb0([[IN:%.+]]: ui8, [[UNUSED:%.+]]: ui8):
+  // CHECK-DAG: [[C17:%.+]] = arith.constant 17
+  // CHECK-DAG: [[C22:%.+]] = arith.constant 22
+  // CHECK-DAG: [[IN_UTOI:%.+]] = builtin.unrealized_conversion_cast [[IN]] : ui8 to i8
+  // CHECK-DAG: [[IN32:%.+]] = arith.extui [[IN_UTOI]]
+  // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C17]]
+  // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {rounding_mode = "SINGLE_ROUND"}
+  // CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C22]]
+  // CHECK-DAG: [[CMIN:%.+]] = arith.constant -128
+  // CHECK-DAG: [[CMAX:%.+]] = arith.constant 127
   // CHECK-DAG: [[LOWER:%.+]] = arith.maxsi [[CMIN]], [[SCALED_ZEROED]]
-  // CHECK-DAG: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
-  // CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
-  // CHECK: linalg.yield [[TRUNC]]
+  // CHECK: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
+  // CHECK: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
+  // CHECK: [[TRUNC_ITOU:%.+]] = builtin.unrealized_conversion_cast [[TRUNC]] : i8 to ui8
+  // CHECK: linalg.yield [[TRUNC_ITOU]]
   %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
   %shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
-  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi48>} : () -> tensor<1xi48>
-  %output_zp = "tosa.const"() {values = dense<-22> : tensor<1xi8>} : () -> tensor<1xi8>
-  %1 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<2xi48>, tensor<1xi16>, tensor<1xi8>, tensor<1xi48>, tensor<1xi8>) -> tensor<2xi8>
+  %input_zp = "tosa.const"() {values = dense<17> : tensor<1xi8>} : () -> tensor<1xi8>
+  %output_zp = "tosa.const"() {values = dense<22> : tensor<1xi8>} : () -> tensor<1xi8>
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<2xui8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xui8>
 
-  // CHECK: return
   return
 }
 

@llvmbot
Copy link
Member

llvmbot commented May 22, 2025

@llvm/pr-subscribers-mlir

Author: Artem Gindinson (AGindinson)

Changes

Along with the changes to rescale op attributes, commit 7208649 dropped the builtin casts between signed and signless types. However, explicitly unsigned types are still legal input and output values from the TOSA IR perspective, and TF-to-TOSA's ConvertUint8ToInt8 pass generates just that.

The change adds back the casts when the unsigned<->signless semantics are explicit in the underlying tensor types. This prevents the conversion routine from trying to generate illegal arith casts that are contrained to signless types. Whether the arith casts themselves are signed or unsigned should still depend on the rescale's *_unsigned attribute values.


Full diff: https://github.com/llvm/llvm-project/pull/141096.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+15)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+133-30)
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 0b69cd2814fb9..6d73f23e2aae1 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1492,6 +1492,15 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
                                                 : blockArgs[multiplierArg];
           Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];
 
+          if (valueTy.isUnsignedInteger()) {
+            value = nestedBuilder
+                        .create<UnrealizedConversionCastOp>(
+                            nestedLoc,
+                            nestedBuilder.getIntegerType(
+                                valueTy.getIntOrFloatBitWidth()),
+                            value)
+                        .getResult(0);
+          }
           if (valueTy.getIntOrFloatBitWidth() < 32) {
             if (op.getInputUnsigned()) {
               value = nestedBuilder.create<arith::ExtUIOp>(
@@ -1537,6 +1546,12 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
                 value);
           }
 
+          if (outIntType.isUnsignedInteger()) {
+            value = nestedBuilder
+                        .create<UnrealizedConversionCastOp>(nestedLoc,
+                                                            outIntType, value)
+                        .getResult(0);
+          }
           nestedBuilder.create<linalg::YieldOp>(loc, value);
         });
 
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 185f1973ecdc6..8fc3d47d9f2ae 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1152,16 +1152,50 @@ func.func @rescale_i8(%arg0 : tensor<2xi8>) -> () {
 // -----
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
 
-// CHECK-LABEL: @rescale_i8_unsigned_output
+// CHECK-LABEL: @rescale_i8_unsigned_output_explicit
 // CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
-func.func @rescale_i8_unsigned_output(%arg0 : tensor<2xi8>) -> () {
+func.func @rescale_i8_unsigned_output_explicit(%arg0 : tensor<2xi8>) -> () {
+  // CHECK: [[C0:%.+]] = arith.constant 19689
+  // CHECK: [[C1:%.+]] = arith.constant 15
+  // CHECK: [[INIT:%.+]] = tensor.empty()
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi8>) outs([[INIT]] : tensor<2xui8>)
+  // CHECK: ^bb0([[IN:%.+]]: i8, [[UNUSED:%.+]]: ui8):
+  // CHECK-DAG: [[C17:%.+]] = arith.constant 17
+  // CHECK-DAG: [[C234:%.+]] = arith.constant 234
+  // CHECK-DAG: [[IN32:%.+]] = arith.extsi [[IN]]
+  // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C17]]
+  // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {rounding_mode = "SINGLE_ROUND"}
+  // CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C234]]
+  // CHECK-DAG: [[CMIN:%.+]] = arith.constant 0
+  // CHECK-DAG: [[CMAX:%.+]] = arith.constant 255
+  // CHECK-DAG: [[LOWER:%.+]] = arith.maxsi [[CMIN]], [[SCALED_ZEROED]]
+  // CHECK: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
+  // CHECK: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
+  // CHECK: [[TRUNC_ITOU:%.+]] = builtin.unrealized_conversion_cast [[TRUNC]] : i8 to ui8
+  // CHECK: linalg.yield [[TRUNC_ITOU]]
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+  %shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<17> : tensor<1xi8>} : () -> tensor<1xi8>
+  %output_zp = "tosa.const"() {values = dense<-22> : tensor<1xi8>} : () -> tensor<1xi8>
+  %1 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xui8>
+
+  // CHECK: return
+  return
+}
+
+// -----
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
+
+// CHECK-LABEL: @rescale_i8_unsigned_output_implicit
+// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
+func.func @rescale_i8_unsigned_output_implicit(%arg0 : tensor<2xi8>) -> () {
   // CHECK: [[C0:%.+]] = arith.constant 19689
   // CHECK: [[C1:%.+]] = arith.constant 15
   // CHECK: [[INIT:%.+]] = tensor.empty()
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi8>) outs([[INIT]] : tensor<2xi8>)
   // CHECK: ^bb0([[IN:%.+]]: i8, [[UNUSED:%.+]]: i8):
-  // CHECK: [[C17:%.+]] = arith.constant 17
-  // CHECK: [[C234:%.+]] = arith.constant 234
+  // CHECK-DAG: [[C17:%.+]] = arith.constant 17
+  // CHECK-DAG: [[C234:%.+]] = arith.constant 234
   // CHECK-DAG: [[IN32:%.+]] = arith.extsi [[IN]]
   // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C17]]
   // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {rounding_mode = "SINGLE_ROUND"}
@@ -1169,8 +1203,9 @@ func.func @rescale_i8_unsigned_output(%arg0 : tensor<2xi8>) -> () {
   // CHECK-DAG: [[CMIN:%.+]] = arith.constant 0
   // CHECK-DAG: [[CMAX:%.+]] = arith.constant 255
   // CHECK-DAG: [[LOWER:%.+]] = arith.maxsi [[CMIN]], [[SCALED_ZEROED]]
-  // CHECK-DAG: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
-  // CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
+  // CHECK: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
+  // CHECK: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
+  // CHECK-NOT: builtin.unrealized_conversion_cast [[TRUNC]] : i8 to i8
   // CHECK: linalg.yield [[TRUNC]]
   %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
   %shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
@@ -1182,6 +1217,39 @@ func.func @rescale_i8_unsigned_output(%arg0 : tensor<2xi8>) -> () {
   return
 }
 
+// -----
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
+
+// CHECK-LABEL: @rescale_i48_unsigned_output_implicit
+// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
+func.func @rescale_i48_unsigned_output_implicit(%arg0 : tensor<2xi48>) -> () {
+  // CHECK: [[C19689:%.+]] = arith.constant 19689
+  // CHECK: [[C15:%.+]] = arith.constant 15
+  // CHECK: [[INIT:%.+]] = tensor.empty()
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi48>) outs([[INIT]] : tensor<2xi8>)
+  // CHECK: ^bb0([[IN:%.+]]: i48, [[UNUSED:%.+]]: i8):
+  // CHECK-NOT: builtin.unrealized_conversion_cast [[IN]] : i48 to i48
+  // CHECK-DAG: [[C0:%.+]] = arith.constant 0
+  // CHECK-DAG: [[C234:%.+]] = arith.constant 234
+  // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN]], [[C0]]
+  // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C19689]], [[C15]] {rounding_mode = "SINGLE_ROUND"}
+  // CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C234]]
+  // CHECK-DAG: [[CMIN:%.+]] = arith.constant 0
+  // CHECK-DAG: [[CMAX:%.+]] = arith.constant 255
+  // CHECK-DAG: [[LOWER:%.+]] = arith.maxsi [[CMIN]], [[SCALED_ZEROED]]
+  // CHECK-DAG: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
+  // CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
+  // CHECK: linalg.yield [[TRUNC]]
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+  %shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi48>} : () -> tensor<1xi48>
+  %output_zp = "tosa.const"() {values = dense<-22> : tensor<1xi8>} : () -> tensor<1xi8>
+  %1 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<2xi48>, tensor<1xi16>, tensor<1xi8>, tensor<1xi48>, tensor<1xi8>) -> tensor<2xi8>
+
+  // CHECK: return
+  return
+}
+
 // -----
 
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
@@ -1230,19 +1298,52 @@ func.func @rescale_dyn(%arg0 : tensor<1x?x?x32xi32>) -> () {
 }
 
 // -----
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
 
+// CHECK-LABEL: @rescale_i8_unsigned_input_explicit
+// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
+func.func @rescale_i8_unsigned_input_explicit(%arg0 : tensor<2xui8>) -> () {
+  // CHECK: [[C0:%.+]] = arith.constant 19689
+  // CHECK: [[C1:%.+]] = arith.constant 15
+  // CHECK: [[INIT:%.+]] = tensor.empty()
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xui8>) outs([[INIT]] : tensor<2xi8>)
+  // CHECK: ^bb0([[IN:%.+]]: ui8, [[UNUSED:%.+]]: i8):
+  // CHECK-DAG: [[C17:%.+]] = arith.constant 17
+  // CHECK-DAG: [[C22:%.+]] = arith.constant 22
+  // CHECK-DAG: [[IN_UTOI:%.+]] = builtin.unrealized_conversion_cast [[IN]] : ui8 to i8
+  // CHECK-DAG: [[IN32:%.+]] = arith.extui [[IN_UTOI]]
+  // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C17]]
+  // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {rounding_mode = "SINGLE_ROUND"}
+  // CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C22]]
+  // CHECK-DAG: [[CMIN:%.+]] = arith.constant -128
+  // CHECK-DAG: [[CMAX:%.+]] = arith.constant 127
+  // CHECK-DAG: [[LOWER:%.+]] = arith.maxsi [[CMIN]], [[SCALED_ZEROED]]
+  // CHECK: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
+  // CHECK: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
+  // CHECK: linalg.yield [[TRUNC]]
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+  %shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<17> : tensor<1xi8>} : () -> tensor<1xi8>
+  %output_zp = "tosa.const"() {values = dense<22> : tensor<1xi8>} : () -> tensor<1xi8>
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<2xui8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8>
+
+  return
+}
+
+// -----
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
 
-// CHECK-LABEL: @rescale_i8_unsigned_input
+// CHECK-LABEL: @rescale_i8_unsigned_input_implicit
 // CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
-func.func @rescale_i8_unsigned_input(%arg0 : tensor<2xi8>) -> () {
+func.func @rescale_i8_unsigned_input_implicit(%arg0 : tensor<2xi8>) -> () {
   // CHECK: [[C0:%.+]] = arith.constant 19689
   // CHECK: [[C1:%.+]] = arith.constant 15
   // CHECK: [[INIT:%.+]] = tensor.empty()
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi8>) outs([[INIT]] : tensor<2xi8>)
   // CHECK: ^bb0([[IN:%.+]]: i8, [[UNUSED:%.+]]: i8):
-  // CHECK: [[C128:%.+]] = arith.constant 128
-  // CHECK: [[C22:%.+]] = arith.constant 22
+  // CHECK-NOT: builtin.unrealized_conversion_cast [[IN]] : i8 to i8
+  // CHECK-DAG: [[C128:%.+]] = arith.constant 128
+  // CHECK-DAG: [[C22:%.+]] = arith.constant 22
   // CHECK-DAG: [[IN32:%.+]] = arith.extui [[IN]]
   // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C128]]
   // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {rounding_mode = "SINGLE_ROUND"}
@@ -1265,32 +1366,34 @@ func.func @rescale_i8_unsigned_input(%arg0 : tensor<2xi8>) -> () {
 // -----
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
 
-// CHECK-LABEL: @rescale_i48_unsigned_output
+// CHECK-LABEL: @rescale_i8_unsigned_input_output_explicit
 // CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
-func.func @rescale_i48_unsigned_output(%arg0 : tensor<2xi48>) -> () {
-  // CHECK: [[C19689:%.+]] = arith.constant 19689
-  // CHECK: [[C15:%.+]] = arith.constant 15
+func.func @rescale_i8_unsigned_input_output_explicit(%arg0 : tensor<2xui8>) -> () {
+  // CHECK: [[C0:%.+]] = arith.constant 19689
+  // CHECK: [[C1:%.+]] = arith.constant 15
   // CHECK: [[INIT:%.+]] = tensor.empty()
-  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi48>) outs([[INIT]] : tensor<2xi8>)
-  // CHECK: ^bb0([[IN:%.+]]: i48, [[UNUSED:%.+]]: i8):
-  // CHECK: [[C0:%.+]] = arith.constant 0
-  // CHECK: [[C234:%.+]] = arith.constant 234
-  // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN]], [[C0]]
-  // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C19689]], [[C15]] {rounding_mode = "SINGLE_ROUND"}
-  // CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C234]]
-  // CHECK-DAG: [[CMIN:%.+]] = arith.constant 0
-  // CHECK-DAG: [[CMAX:%.+]] = arith.constant 255
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xui8>) outs([[INIT]] : tensor<2xui8>)
+  // CHECK: ^bb0([[IN:%.+]]: ui8, [[UNUSED:%.+]]: ui8):
+  // CHECK-DAG: [[C17:%.+]] = arith.constant 17
+  // CHECK-DAG: [[C22:%.+]] = arith.constant 22
+  // CHECK-DAG: [[IN_UTOI:%.+]] = builtin.unrealized_conversion_cast [[IN]] : ui8 to i8
+  // CHECK-DAG: [[IN32:%.+]] = arith.extui [[IN_UTOI]]
+  // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C17]]
+  // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {rounding_mode = "SINGLE_ROUND"}
+  // CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C22]]
+  // CHECK-DAG: [[CMIN:%.+]] = arith.constant -128
+  // CHECK-DAG: [[CMAX:%.+]] = arith.constant 127
   // CHECK-DAG: [[LOWER:%.+]] = arith.maxsi [[CMIN]], [[SCALED_ZEROED]]
-  // CHECK-DAG: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
-  // CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
-  // CHECK: linalg.yield [[TRUNC]]
+  // CHECK: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
+  // CHECK: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
+  // CHECK: [[TRUNC_ITOU:%.+]] = builtin.unrealized_conversion_cast [[TRUNC]] : i8 to ui8
+  // CHECK: linalg.yield [[TRUNC_ITOU]]
   %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
   %shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
-  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi48>} : () -> tensor<1xi48>
-  %output_zp = "tosa.const"() {values = dense<-22> : tensor<1xi8>} : () -> tensor<1xi8>
-  %1 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<2xi48>, tensor<1xi16>, tensor<1xi8>, tensor<1xi48>, tensor<1xi8>) -> tensor<2xi8>
+  %input_zp = "tosa.const"() {values = dense<17> : tensor<1xi8>} : () -> tensor<1xi8>
+  %output_zp = "tosa.const"() {values = dense<22> : tensor<1xi8>} : () -> tensor<1xi8>
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<2xui8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xui8>
 
-  // CHECK: return
   return
 }
 

@RoboTux
Copy link
Contributor

RoboTux commented May 23, 2025

Hi Artem,

First of all my apologies for the trouble my commit caused. How do you manage to lower those unrealized cast? When the only way to convey signedness was through signed/unsigned type we had unrealized_conversion_cast legalization failure for i8->ui8 when lowering from Arith to SPIR-V. I've just tried again with your patch and a rescale from i8 to ui8 and I'm getting:

rescale.mlir:6:8: error: failed to legalize operation 'builtin.unrealized_conversion_cast'

Is unrealized_conversion_cast the correct way to cast between signed/unsigned and signless type? If yes then it's just a bug in the SPIR-V conversion and this patch is fine.

@AGindinson
Copy link
Contributor Author

AGindinson commented May 23, 2025

Hi Thomas,

No problem, and very good that you've brought up SPIR-V - I haven't checked that, as I was using a non-SPIRV pipeline.

Is unrealized_conversion_cast the correct way to cast between signed/unsigned and signless type? If yes then it's just a bug in the SPIR-V conversion and this patch is fine.

I believe that's the correct way to view it, yes. Given the role of the Builtin dialect, it would be natural to expect successful conversions for the other dialects.
https://mlir.llvm.org/docs/Dialects/Builtin/#builtinunrealized_conversion_cast-unrealizedconversioncastop

If without the builtin casts the SPIR-V conversion succeeds and the signedness looks correct, I think it would suffice to mark the builtin op as legal wherever that gets triggered from. Like here: https://github.com/llvm/llvm-project/blob/main/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp#L1358. Would you prefer I do this within the current PR or as a follow-up?

@AGindinson
Copy link
Contributor Author

@RoboTux Only now that I've understood it's actually ArithToSPIRV that fails for you. Is there an mlir-opt repro that I could try for your case?

@RoboTux
Copy link
Contributor

RoboTux commented May 24, 2025

@RoboTux Only now that I've understood it's actually ArithToSPIRV that fails for you. Is there an mlir-opt repro that I could try for your case?

Nevermind, it must be something with our pipeline. I realized we are using the Arith to SPIR-V pattern but not the pass and were missing to mark the op as legal. Still now it fails in SPIRVLowerABIAttributesPass where we do use the upstream pass. But if this op is the right way to do it then this is just a separate bug.

@RoboTux
Copy link
Contributor

RoboTux commented May 24, 2025

@RoboTux Only now that I've understood it's actually ArithToSPIRV that fails for you. Is there an mlir-opt repro that I could try for your case?

Nevermind, it must be something with our pipeline. I realized we are using the Arith to SPIR-V pattern but not the pass and were missing to mark the op as legal. Still now it fails in SPIRVLowerABIAttributesPass where we do use the upstream pass. But if this op is the right way to do it then this is just a separate bug.

FYI (but again this is not your responsability) adding a conversion pattern from UnrealizedConversionCast to SPIRV's BitcastOp seems to solve my issues. Otherwise if we keep UnrealizedConversionCast then we need to teach the serializer about it but since SPIRV's bitcast seems to accept signedness cast I think we should just use it.

Signed-off-by: Artem Gindinson <gindinson@roofline.ai>
@AGindinson
Copy link
Contributor Author

@RoboTux Thanks Thomas for the additional context! Explicitly lowering the builtin casts into SPIRV Bitcast ops seems correct and consistent.

@AGindinson AGindinson requested a review from RoboTux May 26, 2025 07:29
Copy link
Contributor

@RoboTux RoboTux left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@AGindinson
Copy link
Contributor Author

@RoboTux Could you please help with merging this? I'm yet to become eligible for write access :/

@RoboTux RoboTux merged commit d03f30f into llvm:main May 26, 2025
7 checks passed
Copy link

@AGindinson Congratulations on having your first Pull Request (PR) merged into the LLVM Project!

Your changes will be combined with recent changes from other authors, then tested by our build bots. If there is a problem with a build, you may receive a report in an email or a comment on this PR.

Please check whether problems have been caused by your change specifically, as the builds can include changes from many authors. It is not uncommon for your change to be included in a build that fails due to someone else's changes, or infrastructure issues.

How to do this, and the rest of the post-merge process, is covered in detail here.

If your change does cause a problem, it may be reverted, or you can revert it yourself. This is a normal part of LLVM development. You can fix your changes and open a new PR to merge them again.

If you don't get any reports, no action is required from you. Your changes are working as expected, well done!

@AGindinson AGindinson deleted the unsigned-rescale-lowering branch May 26, 2025 11:53
sivan-shani pushed a commit to sivan-shani/llvm-project that referenced this pull request Jun 3, 2025
…#141096)

Along with the changes to rescale op attributes, commit 7208649 dropped
the builtin casts between signed and signless types. However, explicitly
unsigned types are still legal input and output values from the TOSA IR
perspective.

The change adds back the casts when the unsigned<->signless semantics
are explicit in the underlying tensor types. This prevents the
conversion routine from trying to generate illegal `arith` casts that
are constrained to signless types. Whether the `arith` casts themselves
are signed or unsigned should still depend on the rescale's `*_unsigned`
attribute values.

---------

Signed-off-by: Artem Gindinson <gindinson@roofline.ai>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants