[mlir][tosa] Require signless types in validation and add corresponding conversion pass #144367

Open: wants to merge 1 commit into base: main

lhutton1

Firstly, this commit requires all integer types to be signless in the strict mode of the validation pass, because the TOSA specification requires signless types on operations. The "strict" mode of the validation pass is the final check for conformance to the TOSA specification, and it is often run before conversion to other formats.
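
The rule described above can be illustrated with a small standalone sketch. This is a toy model, not the code in `TosaValidation.cpp`: `ToyIntType`, `allowUnsigned` and `isLegalElementType` are hypothetical names, and the sketch assumes that every signless width is legal, which the real pass may restrict further. The only behaviour it tries to capture is that unsigned element types are tolerated on `tosa.rescale` when strict mode is off, and nowhere when strict mode is on.

```cpp
#include <cassert>

// Toy stand-in for an integer element type; NOT the MLIR IntegerType API.
struct ToyIntType {
  unsigned width;
  bool isUnsigned; // false models a signless type such as i8
};

// Unsigned element types are tolerated only on rescale, and only when the
// validation is not running in strict mode.
bool allowUnsigned(bool strictOpSpecAlignment, bool opIsRescale) {
  return !strictOpSpecAlignment && opIsRescale;
}

// Simplified legality check. Assumption: every signless width is accepted
// here; the real pass may apply further checks.
bool isLegalElementType(ToyIntType type, bool unsignedAllowed) {
  if (!type.isUnsigned)
    return true;
  // Unsigned types are limited to ui8/ui16/ui32.
  return unsignedAllowed &&
         (type.width == 8 || type.width == 16 || type.width == 32);
}
```

With this model, a `ui8` operand on a rescale is accepted when strict mode is off and rejected when it is on, which matches the new `expected-error` lines added to `invalid.mlir`.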

In addition, a conversion pass `--tosa-convert-integer-type-to-signless` is provided to convert all integer types to signless. The intention is that this pass can be run before the validation pass. After running this pass, the user is responsible for carrying the signedness information of the inputs and outputs independently.
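
As a rough illustration of what "carrying the information independently" could look like, here is a standalone sketch on toy types. It does not use the MLIR API: `ToyTensorType`, `toSignless` and `convertSignature` are hypothetical names introduced only for this example. The conversion keeps shape and bit width and drops the signedness, and the caller saves the original signedness of each signature entry in a side table before it is lost.

```cpp
#include <cassert>
#include <vector>

// Toy stand-ins; these are NOT MLIR classes.
enum class Signedness { Signless, Signed, Unsigned };

struct ToyTensorType {
  std::vector<long> shape;
  unsigned width;
  Signedness signedness;
};

// Keep shape and bit width, drop the signedness.
ToyTensorType toSignless(const ToyTensorType &type) {
  return ToyTensorType{type.shape, type.width, Signedness::Signless};
}

// Convert every type in a signature in place and return the original
// signedness of each entry, which the caller must keep if it still needs it.
std::vector<Signedness> convertSignature(std::vector<ToyTensorType> &types) {
  std::vector<Signedness> original;
  original.reserve(types.size());
  for (ToyTensorType &type : types) {
    original.push_back(type.signedness);
    type = toSignless(type);
  }
  return original;
}
```

For a signature such as `(tensor<1x1xui8>) -> tensor<1x1xi8>`, the converted types would all be signless, and the returned side table would record that the first entry was originally unsigned.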

…ng conversion pass

Firstly, this commit requires that all types are signless in the strict
mode of the validation pass. This is because signless types on operations
are required by the TOSA specification. The "strict" mode in the
validation pass is the final check for TOSA conformance to the
specification, which can often be used for conversion to other formats.

In addition, a conversion pass `--tosa-convert-integer-type-to-signless`
is provided to allow a user to convert all integer types to signless.
The intention is that this pass can be run before the validation pass.
Following use of this pass, input/output information should be carried
independently by the user.

Change-Id: Id7aebf0071c9a7516c77f55062db82760c0da533
@llvmbot

llvmbot commented Jun 16, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-tosa

Author: Luke Hutton (lhutton1)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/144367.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td (+14)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/Tosa/Transforms/TosaConvertIntegerTypeToSignless.cpp (+134)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+5-4)
  • (modified) mlir/test/Dialect/Tosa/invalid.mlir (+2)
  • (added) mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir (+73)
  • (added) mlir/test/Dialect/Tosa/tosa-validation-valid.mlir (+31)
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index d005a4cc6859c..b96682843538c 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -127,4 +127,18 @@ def TosaReduceTransposes : Pass<"tosa-reduce-transposes", "func::FuncOp"> {
   }];
 }
 
+def TosaConvertIntegerTypeToSignless : Pass<"tosa-convert-integer-type-to-signless", "func::FuncOp"> {
+  let summary = "Convert integer types to signless";
+  let description = [{
+    This pass converts signed or unsigned integer types to signless. It
+    currently does this greedily for all operators and can also change the
+    signature of the function. Should the signature of the entrypoint
+    function change, it will be the responsibility of the user to carry
+    signedness information of the inputs and outputs independently.
+
+    This can be a useful transformation for conversion to other formats
+    that require strict adherence to the TOSA specification.
+  }];
+}
+
 #endif // MLIR_DIALECT_TOSA_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
index bbf079faea3d0..803993bb1008d 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_dialect_library(MLIRTosaTransforms
+  TosaConvertIntegerTypeToSignless.cpp
   TosaDecomposeTransposeConv.cpp
   TosaDecomposeDepthwise.cpp
   TosaFolders.cpp
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaConvertIntegerTypeToSignless.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaConvertIntegerTypeToSignless.cpp
new file mode 100644
index 0000000000000..3085e56ceebc0
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaConvertIntegerTypeToSignless.cpp
@@ -0,0 +1,134 @@
+//===- TosaConvertIntegerTypeToSignless.cpp
+//-------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===-------------------------------------------------------------------------------===//
+
+// -----------
+// Motivation:
+// -----------
+
+// The TOSA specification uses a signless type system, which means that
+// information about signedness must be encapsulated by the operations
+// themselves. For example, tosa.rescale provides the attributes `input_unsigned`
+// and `output_unsigned` to indicate whether the input/output should be
+// interpreted as unsigned or signed.
+
+// The TOSA dialect, on the other hand, allows the use of signed or unsigned
+// types in addition to signless. As such, when converting from TOSA dialect to
+// other formats, we need to ensure that we conform to the TOSA specification.
+
+// ---------
+// Overview:
+// ---------
+
+// This pass converts signed or unsigned integer types to signless. It currently
+// does this greedily for all operators and can also change the signature of the
+// function. Should the signature of the entrypoint function change, it will be
+// the responsibility of the user to carry signedness information of the inputs
+// and outputs independently.
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace tosa {
+
+#define GEN_PASS_DEF_TOSACONVERTINTEGERTYPETOSIGNLESS
+#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
+
+namespace {
+class ToSignlessTensorTypeConverter : public TypeConverter {
+  static Type convertType(Type type) {
+    const auto tensorType = dyn_cast<TensorType>(type);
+    if (!tensorType)
+      return type;
+
+    const auto intType = dyn_cast<IntegerType>(tensorType.getElementType());
+    if (!intType ||
+        intType.getSignedness() == IntegerType::SignednessSemantics::Signless)
+      return type;
+
+    const auto signlessType = IntegerType::get(
+        intType.getContext(), intType.getWidth(), IntegerType::Signless);
+    return tensorType.cloneWith(std::nullopt, signlessType);
+  }
+
+public:
+  explicit ToSignlessTensorTypeConverter() { addConversion(convertType); }
+};
+
+class ConvertGenericOpWithIntegerTensorType : public ConversionPattern {
+public:
+  ConvertGenericOpWithIntegerTensorType(TypeConverter &typeConverter,
+                                        MLIRContext *context)
+      : ConversionPattern(typeConverter, MatchAnyOpTypeTag{}, 0, context) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    // Convert integer types to signless
+    SmallVector<Type, 4> resultTypes;
+    if (failed(typeConverter->convertTypes(op->getResultTypes(), resultTypes)))
+      return failure();
+
+    // Create new op with replaced operands and results
+    auto *newOp = Operation::create(
+        op->getLoc(), op->getName(), resultTypes, operands, op->getAttrs(),
+        op->getPropertiesStorage(), op->getSuccessors(), op->getNumRegions());
+
+    // Handle regions in e.g. tosa.cond_if and tosa.while_loop
+    for (auto regions : llvm::zip(op->getRegions(), newOp->getRegions())) {
+      Region &before = std::get<0>(regions);
+      Region &parent = std::get<1>(regions);
+      rewriter.inlineRegionBefore(before, parent, parent.end());
+      if (failed(rewriter.convertRegionTypes(&parent, *typeConverter)))
+        return failure();
+    }
+
+    // Replace with rewritten op
+    rewriter.insert(newOp);
+    rewriter.replaceOp(op, newOp->getResults());
+    return success();
+  }
+};
+
+class TosaConvertIntegerTypeToSignless
+    : public impl::TosaConvertIntegerTypeToSignlessBase<
+          TosaConvertIntegerTypeToSignless> {
+public:
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    ConversionTarget target(*context);
+    ToSignlessTensorTypeConverter typeConverter;
+
+    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
+      return typeConverter.isSignatureLegal(op.getFunctionType()) &&
+             typeConverter.isLegal(&op.getBody());
+    });
+    target.markUnknownOpDynamicallyLegal([&](Operation *op) {
+      return typeConverter.isLegal(op->getOperandTypes()) &&
+             typeConverter.isLegal(op->getResultTypes());
+    });
+
+    RewritePatternSet patterns(context);
+    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
+        patterns, typeConverter);
+    patterns.add<ConvertGenericOpWithIntegerTensorType>(typeConverter, context);
+
+    if (failed(
+            applyFullConversion(getOperation(), target, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+
+} // namespace
+
+} // namespace tosa
+} // namespace mlir
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 229f42d3178b5..3f27849b8c90c 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -1320,13 +1320,14 @@ void TosaValidation::runOnOperation() {
 
     // validate operator element types:
     // - rescale operator is allowed to have ui8/ui16/ui32
-    //   operands/results
+    //   operands/results when strictOpSpecAlignment is false
     // - perform valid element type check at the beginning to
     //   protect rest of code against quantized element types
-    const bool opIsRescale = isa<tosa::RescaleOp>(op);
+    const bool allowUnsigned =
+        !strictOpSpecAlignment && isa<tosa::RescaleOp>(op);
     for (Value operand : op->getOperands()) {
       auto elementTy = getElementTypeOrSelf(operand);
-      if (!isValidElementType(elementTy, opIsRescale)) {
+      if (!isValidElementType(elementTy, allowUnsigned)) {
         op->emitOpError() << "is not profile-aligned: element type "
                           << elementTy << " is not legal";
         return signalPassFailure();
@@ -1334,7 +1335,7 @@ void TosaValidation::runOnOperation() {
     }
     for (Type resultTy : op->getResultTypes()) {
       auto elementTy = getElementTypeOrSelf(resultTy);
-      if (!isValidElementType(elementTy, opIsRescale)) {
+      if (!isValidElementType(elementTy, allowUnsigned)) {
         op->emitOpError() << "is not profile-aligned: element type "
                           << elementTy << " is not legal";
         return signalPassFailure();
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 805522799a6d8..e25b3b7ef3e3a 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -2000,6 +2000,7 @@ func.func @test_rescale_input_unsigned(%arg0: tensor<1x1xui8>) -> (tensor<1x1xi8
   %1 = "tosa.const"() <{values = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
   %2 = "tosa.const"() <{values = dense<3> : tensor<1xi8>}> : () -> tensor<1xi8>
   %3 = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.rescale' op is not profile-aligned: element type 'ui8' is not legal}}
   %r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = true, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xui8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xi8>
   return %r : tensor<1x1xi8>
 }
@@ -2012,6 +2013,7 @@ func.func @test_rescale_output_unsigned(%arg0: tensor<1x1xi8>) -> (tensor<1x1xui
   %1 = "tosa.const"() <{values = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
   %2 = "tosa.const"() <{values = dense<3> : tensor<1xi8>}> : () -> tensor<1xi8>
   %3 = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.rescale' op is not profile-aligned: element type 'ui8' is not legal}}
   %r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = false, output_unsigned = true, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xui8>
   return %r : tensor<1x1xui8>
 }
diff --git a/mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir b/mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir
new file mode 100644
index 0000000000000..38ac8d8fb66d9
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir
@@ -0,0 +1,73 @@
+// RUN: mlir-opt --split-input-file --tosa-convert-integer-type-to-signless %s | FileCheck %s
+
+// -----
+
+// CHECK-LABEL: test_rescale_output_unsigned
+// CHECK: %arg0: tensor<1x1xi8>
+func.func @test_rescale_output_unsigned(%arg0: tensor<1x1xi8>) -> (tensor<1x1xui8>) {
+  %0 = "tosa.const"() <{values = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %1 = "tosa.const"() <{values = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
+  %2 = "tosa.const"() <{values = dense<3> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %3 = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8>
+  // CHECK: %[[RESCALE:.*]] = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = false, output_unsigned = true, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xi8>
+  %r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = false, output_unsigned = true, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xui8>
+  // CHECK: return %[[RESCALE]] : tensor<1x1xi8>
+  return %r : tensor<1x1xui8>
+}
+
+// -----
+
+// CHECK-LABEL: test_rescale_input_unsigned
+// CHECK: %arg0: tensor<1x1xi16>
+func.func @test_rescale_input_unsigned(%arg0: tensor<1x1xui16>) -> (tensor<1x1xi8>) {
+  %0 = "tosa.const"() <{values = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %1 = "tosa.const"() <{values = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
+  %2 = "tosa.const"() <{values = dense<3> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %3 = "tosa.const"() <{values = dense<32768> : tensor<1xi16>}> : () -> tensor<1xi16>
+  // CHECK: %[[RESCALE:.*]] = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = true, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xi16>, tensor<1xi32>, tensor<1xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<1x1xi8>
+  %r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = true, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xui16>, tensor<1xi32>, tensor<1xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<1x1xi8>
+  // CHECK: return %[[RESCALE]] : tensor<1x1xi8>
+  return %r : tensor<1x1xi8>
+}
+
+// -----
+
+// CHECK-LABEL: test_unsigned_function_signature
+// CHECK: %arg0: tensor<1xi8>, %arg1: tensor<1xi8>
+func.func @test_unsigned_function_signature(%arg0: tensor<1xui8>, %arg1: tensor<1xui8>) -> (tensor<1xui8>, tensor<1xui8>) {
+  // CHECK: return %arg0, %arg1 : tensor<1xi8>, tensor<1xi8>
+  return %arg0, %arg1 : tensor<1xui8>, tensor<1xui8>
+}
+
+// -----
+
+// CHECK-LABEL: test_no_change
+// CHECK: %arg0: tensor<13x21x3xi8>
+func.func @test_no_change(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x3xi8> {
+  %0 = tosa.reverse %arg0 {axis = 0 : i32} : (tensor<13x21x3xi8>) -> tensor<13x21x3xi8>
+  // CHECK: return %0 : tensor<13x21x3xi8>
+  return %0 : tensor<13x21x3xi8>
+}
+
+// -----
+
+// CHECK-LABEL: test_regions
+// CHECK: %arg0: tensor<i8>, %arg1: tensor<i8>
+func.func @test_regions(%arg0: tensor<ui8>, %arg1: tensor<ui8>, %arg2: tensor<i1>) -> tensor<ui8> {
+  // CHECK: tosa.cond_if %arg2 -> (tensor<i8>)
+  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+  ^bb0(%arg3: tensor<ui8>, %arg4: tensor<ui8>):
+    // CHECK: %1 = tosa.add %arg0, %arg1 : (tensor<i8>, tensor<i8>) -> tensor<i8>
+    %1 = tosa.add %arg0, %arg1 : (tensor<ui8>, tensor<ui8>) -> tensor<ui8>
+    // CHECK: tosa.yield %1 : tensor<i8>
+    tosa.yield %1 : tensor<ui8>
+  },  {
+  ^bb0(%arg3: tensor<ui8>, %arg4: tensor<ui8>):
+    // CHECK: %1 = tosa.sub %arg0, %arg1 : (tensor<i8>, tensor<i8>) -> tensor<i8>
+    %1 = tosa.sub %arg0, %arg1 : (tensor<ui8>, tensor<ui8>) -> tensor<ui8>
+    // CHECK: tosa.yield %1 : tensor<i8>
+    tosa.yield %1 : tensor<ui8>
+  }) : (tensor<i1>, tensor<ui8>, tensor<ui8>) -> tensor<ui8>
+  // CHECK: return %0 : tensor<i8>
+  return %0 : tensor<ui8>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir
new file mode 100644
index 0000000000000..cab14201dc0ce
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/tosa-validation-valid.mlir
@@ -0,0 +1,31 @@
+//--------------------------------------------------------------------------------------------------
+// Test valid IR in terms of the shape and type of tensor, and the argument type of
+// operation. Excludes the profile compliance checking since it is performed earlier in the
+// validation flow.
+//--------------------------------------------------------------------------------------------------
+
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" | FileCheck %s
+
+// -----
+
+// CHECK-LABEL: test_rescale_input_unsigned
+func.func @test_rescale_input_unsigned(%arg0: tensor<1x1xui8>) -> (tensor<1x1xi8>) {
+  %0 = "tosa.const"() <{values = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %1 = "tosa.const"() <{values = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
+  %2 = "tosa.const"() <{values = dense<3> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %3 = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = true, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xui8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xi8>
+  return %r : tensor<1x1xi8>
+}
+
+// -----
+
+// CHECK-LABEL: test_rescale_output_unsigned
+func.func @test_rescale_output_unsigned(%arg0: tensor<1x1xi8>) -> (tensor<1x1xui8>) {
+  %0 = "tosa.const"() <{values = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %1 = "tosa.const"() <{values = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
+  %2 = "tosa.const"() <{values = dense<3> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %3 = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = false, output_unsigned = true, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xui8>
+  return %r : tensor<1x1xui8>
+}
