From 81c13b6af57d7ce15290d82f04438dc394a129e4 Mon Sep 17 00:00:00 2001 From: Tina Jung Date: Mon, 24 Apr 2023 16:29:24 +0100 Subject: [PATCH 1/3] Implement folding for constant tosa.muls * Folds multiplications with constant operands (limited to muls with shift = 0) * Add unit test for the folding * Implement saturating semantics for i32 overflows, might require changes if the spec clarification comes in [0] [0] https://discuss.mlplatform.org/t/integer-multiplication-overflow-handling/187 --- .../mlir/Dialect/Tosa/Transforms/Passes.h | 2 + .../Dialect/Tosa/Transforms/CMakeLists.txt | 1 + .../Tosa/Transforms/TosaFoldConstantMul.cpp | 114 ++++++++++++ .../TosaLayerwiseConstantFoldPass.cpp | 1 + mlir/test/Dialect/Tosa/constant-mul-opt.mlir | 167 ++++++++++++++++++ 5 files changed, 285 insertions(+) create mode 100644 mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantMul.cpp create mode 100644 mlir/test/Dialect/Tosa/constant-mul-opt.mlir diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h index a6d6da13cd2c5..b8983dae1dd9c 100644 --- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h @@ -37,6 +37,8 @@ void populateTosaFoldConstantClampPatterns(MLIRContext *ctx, void populateTosaFoldConstantCastPatterns(MLIRContext *ctx, RewritePatternSet &patterns, bool enableIntCastFolding); +void populateTosaFoldConstantMulPatterns(MLIRContext *ctx, + RewritePatternSet &patterns); void populateTosaFoldConstantPowPatterns(MLIRContext *ctx, RewritePatternSet &patterns); void populateTosaFoldConstantReciprocalPatterns(MLIRContext *ctx, diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt index 1b6a3530d6a14..cbfec4a9890d1 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt @@ -6,6 +6,7 @@ add_mlir_dialect_library(MLIRTosaTransforms 
TosaFoldConstantAdd.cpp TosaFoldConstantCast.cpp TosaFoldConstantClamp.cpp + TosaFoldConstantMul.cpp TosaFoldConstantPow.cpp TosaFoldConstantReciprocal.cpp TosaFoldConstantRSQRT.cpp diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantMul.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantMul.cpp new file mode 100644 index 0000000000000..73d957bd5c9dc --- /dev/null +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantMul.cpp @@ -0,0 +1,114 @@ +//===- TosaFoldConstantMul.cpp --------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Fold TOSA Mul operation on constant data +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/Dialect/Tosa/Transforms/Passes.h" +#include "mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h" +#include "mlir/IR/Matchers.h" +#include "mlir/Pass/Pass.h" +#include +#include + +using namespace mlir; +using namespace mlir::tosa; + +namespace { + +struct TosaFoldConstantMul : public OpRewritePattern { + + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(MulOp mulOp, + PatternRewriter &rewriter) const override { + if (mulOp.getShift() > 0) { + return rewriter.notifyMatchFailure( + mulOp, "Non-zero shift folding is currently not implemented."); + } + + auto leftOp = mulOp.getInput1(); + auto rightOp = mulOp.getInput2(); + + // Check if both tensors are constant + auto rhsIsConstantCheck = + notifyIfNoTosaDenseConstantTensor(leftOp, mulOp, rewriter); + if (failed(rhsIsConstantCheck)) { + return rhsIsConstantCheck; + } + auto lhsIsConstantCheck = + notifyIfNoTosaDenseConstantTensor(rightOp, mulOp, rewriter); + if 
(failed(lhsIsConstantCheck)) { + return lhsIsConstantCheck; + } + + // Extract the tensor values + DenseElementsAttr lhsValues; + matchPattern(leftOp, m_Constant(&lhsValues)); + + DenseElementsAttr rhsValues; + matchPattern(rightOp, m_Constant(&rhsValues)); + + if (!constantBinaryOpShouldBeFolded(mulOp, lhsValues, rhsValues)) { + return rewriter.notifyMatchFailure( + mulOp, "Currently, muls will only be folded if this requires only " + "little additional memory usage."); + } + + DenseElementsAttr newTensor; + + auto lhsElemType = leftOp.getType().getElementType(); + auto rhsElemType = rightOp.getType().getElementType(); + assert(lhsElemType == rhsElemType); + + auto resultType = mulOp.getType(); + auto resultElementType = resultType.getElementType(); + if (isa(lhsElemType)) { + assert(isa(rhsElemType) && + isa(resultElementType)); + auto resultElementWidth = resultElementType.getIntOrFloatBitWidth(); + assert(resultElementWidth == 32 && + "All integer multiplications in TOSA are specified to result in " + "32 bit width"); + // TODO: To implement shifts > 0, capture the shift value stored in the + // mul here + auto intMulFun = [&resultElementWidth](const APInt &first, + const APInt &second) { + // TODO the documentation has conflicting definitions for the behavior + // of overflows + // The sign extend should always be valid as the result type is required + // to be i32 and all other integer input types are smaller or equal + // to 32. 
+ return first.sext(resultElementWidth) + .smul_sat(second.sext(resultElementWidth)); + }; + newTensor = applyElementWise(lhsValues, rhsValues, + resultType, intMulFun); + } else { + assert(isa(lhsElemType) && isa(rhsElemType) && + isa(resultType.getElementType())); + auto mulFun = [](const APFloat &first, const APFloat &second) { + return first * second; + }; + newTensor = applyElementWise(lhsValues, rhsValues, + resultType, mulFun); + } + rewriter.replaceOpWithNewOp(mulOp, newTensor.getType(), newTensor); + + return success(); + } +}; + +} // namespace + +void mlir::tosa::populateTosaFoldConstantMulPatterns( + MLIRContext *ctx, RewritePatternSet &patterns) { + patterns.add(ctx); +} diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp index 1f230b6b7d1e5..0cee7680453b6 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp @@ -54,6 +54,7 @@ struct TosaLayerwiseConstantFoldPass mlir::tosa::populateTosaFoldConstantCastPatterns(ctx, patterns, enableIntCastFolding); mlir::tosa::populateTosaFoldConstantClampPatterns(ctx, patterns); + mlir::tosa::populateTosaFoldConstantMulPatterns(ctx, patterns); mlir::tosa::populateTosaFoldConstantPowPatterns(ctx, patterns); mlir::tosa::populateTosaFoldConstantReciprocalPatterns(ctx, patterns); mlir::tosa::populateTosaFoldConstantRSQRTPatterns(ctx, patterns); diff --git a/mlir/test/Dialect/Tosa/constant-mul-opt.mlir b/mlir/test/Dialect/Tosa/constant-mul-opt.mlir new file mode 100644 index 0000000000000..66d2291afec87 --- /dev/null +++ b/mlir/test/Dialect/Tosa/constant-mul-opt.mlir @@ -0,0 +1,167 @@ +// RUN: mlir-opt --split-input-file -verify-diagnostics --tosa-layerwise-constant-fold %s | FileCheck %s + +// Float multiplications + +// CHECK-LABEL: @mul_fold_float +func.func @mul_fold_float() -> tensor<4xf16> { + // CHECK: [[RES:]] 
={{.*}}tosa.const{{.*}}2.32{{.*}}e+03, -1.49{{.*}}e+01, -0.{{0*}}e+00, -0.{{0*}}e+00 + // CHECK-NOT: tosa.mul + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = + dense<[-17.4978, 4.9882, 0.0, -0.0]> : + tensor<4xf16> + } : () -> tensor<4xf16> + %1 = "tosa.const"() {value = + dense<[-132.7, -3.0, -0.0, 5.0]> : + tensor<4xf16> + } : () -> tensor<4xf16> + %2 = "tosa.mul"(%0, %1) {shift = 0 : i32} : (tensor<4xf16>, tensor<4xf16>) -> tensor<4xf16> + return %2 : tensor<4xf16> +} + +// CHECK-LABEL: @mul_fold_float_infinity_nan +func.func @mul_fold_float_infinity_nan() -> tensor<7xf32> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}0x7F800000, 0x7F800000, 0xFF800000, 0xFF800000, 0x7FC00000, 0xFF800000, 0x7FC00000 + // CHECK-NOT: tosa.mul + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = + dense<[0x7F800000, 0xFF800000, 0x7F800000, 0xFF800000, 0x7FC00000, 0x7F800000, 0xFF800000]> : + tensor<7xf32> + } : () -> tensor<7xf32> + %1 = "tosa.const"() {value = + dense<[3.0, -3.0, -3.0, 3.0, 1.0, 0xFF800000, 0.0]> : + tensor<7xf32> + } : () -> tensor<7xf32> + %2 = "tosa.mul"(%0, %1) {shift = 0 : i32} : (tensor<7xf32>, tensor<7xf32>) -> tensor<7xf32> + return %2 : tensor<7xf32> +} + +// CHECK-LABEL: @add_fold_float_overflow +func.func @add_fold_float_overflow() -> tensor<2xf32> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}0x7F800000, 0xFF800000 + // CHECK-NOT: tosa.mul + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = + dense<[3.1e+38, -3.1e+38]> : + tensor<2xf32> + } : () -> tensor<2xf32> + %1 = "tosa.const"() {value = + dense<[2.1e+38, 1.1e+38]> : + tensor<2xf32> + } : () -> tensor<2xf32> + %2 = "tosa.mul"(%0, %1) {shift = 0 : i32} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> + return %2 : tensor<2xf32> +} + +// ----- +// Int multiplications + +// CHECK-LABEL: @mul_fold_int +func.func @mul_fold_int() -> tensor<4xi32> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}2244, -12, 0, 0 + // CHECK-NOT: tosa.mul + // CHECK: return [[RES]] + %0 = "tosa.const"() 
{value = + dense<[-17, 4, 0, 0]> : + tensor<4xi32> + } : () -> tensor<4xi32> + %1 = "tosa.const"() {value = + dense<[-132, -3, 0, 5]> : + tensor<4xi32> + } : () -> tensor<4xi32> + %2 = "tosa.mul"(%0, %1) {shift = 0 : i32} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + return %2 : tensor<4xi32> +} + +// ----- +// self-multiplication + +// CHECK-LABEL: @mul_fold_int_overflow +// TODO: Change expected behavior if the tosa.mul on i32 should not be +// saturating. Also add a test with different widths in that case. +func.func @mul_fold_int_overflow() -> tensor<4xi32> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}2147483647, 2147483647, -2147483648, -2147483648 + // CHECK-NOT: tosa.mul + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = + dense<[2147483647, 2147483640, -2147483648, -2147483640]> : + tensor<4xi32> + } : () -> tensor<4xi32> + %1 = "tosa.const"() {value = + dense<[1, 10, 1, 30]> : + tensor<4xi32> + } : () -> tensor<4xi32> + %2 = "tosa.mul"(%0, %1) {shift = 0 : i32} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + return %2 : tensor<4xi32> +} + +// CHECK-LABEL: @mul_fold_equal_args +func.func @mul_fold_equal_args() -> tensor<3xi32> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}289, 16, 0 + // CHECK-NOT: tosa.mul + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = + dense<[-17, 4, 0]> : + tensor<3xi32> + } : () -> tensor<3xi32> + %2 = "tosa.mul"(%0, %0) {shift = 0 : i32} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32> + return %2 : tensor<3xi32> +} + +// ----- +// Broadcasted multiplications + +// CHECK-LABEL: @mul_fold_int_broadcast_simple +func.func @mul_fold_int_broadcast_simple() -> tensor<3xi32> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}204, -48, 0 + // CHECK-NOT: tosa.mul + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = + dense<[-17, 4, 0]> : + tensor<3xi32> + } : () -> tensor<3xi32> + %1 = "tosa.const"() {value = + dense<-12> : + tensor<1xi32> + } : () -> tensor<1xi32> + %2 = "tosa.mul"(%0, %1) {shift = 0 : i32} : 
(tensor<3xi32>, tensor<1xi32>) -> tensor<3xi32> + return %2 : tensor<3xi32> +} + +// CHECK-LABEL: @mul_fold_int_broadcast_complex +func.func @mul_fold_int_broadcast_complex() -> tensor<3x3xi32> { + // CHECK: [[RES:]] ={{.*}}tosa.const + // CHECK-SAME{LITERAL}: [[204, -119, -68], + // CHECK-SAME{LITERAL}: [-12, 7, 4], + // CHECK-SAME{LITERAL}: [-228, 133, 76]] + // CHECK-NOT: tosa.mul + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = + dense<[[-17], [1], [19]]> : + tensor<3x1xi32> + } : () -> tensor<3x1xi32> + %1 = "tosa.const"() {value = + dense<[[-12, 7, 4]]> : + tensor<1x3xi32> + } : () -> tensor<1x3xi32> + %2 = "tosa.mul"(%0, %1) {shift = 0 : i32} : (tensor<3x1xi32>, tensor<1x3xi32>) -> tensor<3x3xi32> + return %2 : tensor<3x3xi32> +} + +// CHECK-LABEL: @mul_fold_int_non_zero_shift +func.func @mul_fold_int_non_zero_shift() -> tensor<4xi32> { + // CHECK: [[FIRST:]] ={{.*}}tosa.const + // CHECK-NEXT: [[SECOND:]] ={{.*}}tosa.const + // CHECK-NEXT: [[MUL:]] ={{.*}}tosa.mul{{.*}}[[FIRST]], [[SECOND]] + // CHECK-NEXT: return [[MUL]] + %0 = "tosa.const"() {value = + dense<[-17, 4, 0, 0]> : + tensor<4xi32> + } : () -> tensor<4xi32> + %1 = "tosa.const"() {value = + dense<[-132, -3, 0, 5]> : + tensor<4xi32> + } : () -> tensor<4xi32> + %2 = "tosa.mul"(%0, %1) {shift = 1 : i32} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + return %2 : tensor<4xi32> +} From 3326d8e6d7cf979825f64de138aa73ace250f73c Mon Sep 17 00:00:00 2001 From: Tina Jung Date: Wed, 26 Apr 2023 10:51:06 +0100 Subject: [PATCH 2/3] Update to unspecified overflow behavior of int mul * Do not use saturating multiplication * Emit a warning if overflows occur * Check the behavior in a test * Add a test with a small bit width which would overflow if the mul wouldn't always result in a 32-bit int --- .../Tosa/Transforms/TosaFoldConstantMul.cpp | 27 ++++++++++-------- mlir/test/Dialect/Tosa/constant-mul-opt.mlir | 28 +++++++++++++++---- 2 files changed, 38 insertions(+), 17 deletions(-) diff --git 
a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantMul.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantMul.cpp index 73d957bd5c9dc..61a9551ca28ea 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantMul.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantMul.cpp @@ -74,23 +74,26 @@ struct TosaFoldConstantMul : public OpRewritePattern { assert(isa(rhsElemType) && isa(resultElementType)); auto resultElementWidth = resultElementType.getIntOrFloatBitWidth(); - assert(resultElementWidth == 32 && - "All integer multiplications in TOSA are specified to result in " - "32 bit width"); + assert(resultElementWidth >= lhsElemType.getIntOrFloatBitWidth() && + "The multiplication is expected to have an at least a big output " + "as input type"); // TODO: To implement shifts > 0, capture the shift value stored in the // mul here - auto intMulFun = [&resultElementWidth](const APInt &first, - const APInt &second) { - // TODO the documentation has conflicting definitions for the behavior - // of overflows - // The sign extend should always be valid as the result type is required - // to be i32 and all other integer input types are smaller or equal - // to 32. - return first.sext(resultElementWidth) - .smul_sat(second.sext(resultElementWidth)); + bool mulOverflowed; + auto intMulFun = [&resultElementWidth, &mulOverflowed]( + const APInt &first, const APInt &second) { + bool didOverflow; + auto res = first.sext(resultElementWidth) + .smul_ov(second.sext(resultElementWidth), didOverflow); + mulOverflowed |= didOverflow; + return res; }; newTensor = applyElementWise(lhsValues, rhsValues, resultType, intMulFun); + if (mulOverflowed) { + mulOp.emitWarning( + "Multiplication did overflow. 
The results are unspecified."); + } } else { assert(isa(lhsElemType) && isa(rhsElemType) && isa(resultType.getElementType())); diff --git a/mlir/test/Dialect/Tosa/constant-mul-opt.mlir b/mlir/test/Dialect/Tosa/constant-mul-opt.mlir index 66d2291afec87..8bd94a9266a73 100644 --- a/mlir/test/Dialect/Tosa/constant-mul-opt.mlir +++ b/mlir/test/Dialect/Tosa/constant-mul-opt.mlir @@ -73,14 +73,28 @@ func.func @mul_fold_int() -> tensor<4xi32> { return %2 : tensor<4xi32> } -// ----- -// self-multiplication +// CHECK-LABEL: @mul_fold_i8 +func.func @mul_fold_i8() -> tensor<4xi32> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}204, -12, 0, 0 + // CHECK-NOT: tosa.mul + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = + dense<[-17, 4, -2, 0]> : + tensor<4xi8> + } : () -> tensor<4xi8> + %1 = "tosa.const"() {value = + dense<[-12, -3, 0, 5]> : + tensor<4xi8> + } : () -> tensor<4xi8> + %2 = "tosa.mul"(%0, %1) {shift = 0 : i32} : (tensor<4xi8>, tensor<4xi8>) -> tensor<4xi32> + return %2 : tensor<4xi32> +} // CHECK-LABEL: @mul_fold_int_overflow -// TODO: Change expected behavior if the tosa.mul on i32 should not be -// saturating. Also add a test with different widths in that case. func.func @mul_fold_int_overflow() -> tensor<4xi32> { - // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}2147483647, 2147483647, -2147483648, -2147483648 + // Don't expect any specific results for the overflowing multiplication, just + // that it is folded. + // CHECK: [[RES:]] ={{.*}}tosa.const // CHECK-NOT: tosa.mul // CHECK: return [[RES]] %0 = "tosa.const"() {value = @@ -91,10 +105,14 @@ func.func @mul_fold_int_overflow() -> tensor<4xi32> { dense<[1, 10, 1, 30]> : tensor<4xi32> } : () -> tensor<4xi32> + // expected-warning@below {{Multiplication did overflow. 
The results are unspecified.}} %2 = "tosa.mul"(%0, %1) {shift = 0 : i32} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> return %2 : tensor<4xi32> } +// ----- +// self-multiplication + // CHECK-LABEL: @mul_fold_equal_args func.func @mul_fold_equal_args() -> tensor<3xi32> { // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}289, 16, 0 From 75b6df8a29269ba938a1833b0cd68dbee8bd675d Mon Sep 17 00:00:00 2001 From: Tina Jung Date: Wed, 26 Apr 2023 12:51:05 +0100 Subject: [PATCH 3/3] Minor fixes * Fix typo * initialize bool to false * remove redundant comment and put a small explanation instead --- mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantMul.cpp | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantMul.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantMul.cpp index 61a9551ca28ea..58b1b1d7a5cf5 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantMul.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantMul.cpp @@ -75,11 +75,12 @@ struct TosaFoldConstantMul : public OpRewritePattern { isa(resultElementType)); auto resultElementWidth = resultElementType.getIntOrFloatBitWidth(); assert(resultElementWidth >= lhsElemType.getIntOrFloatBitWidth() && - "The multiplication is expected to have an at least a big output " + "The multiplication is expected to have an at least as big output " "as input type"); - // TODO: To implement shifts > 0, capture the shift value stored in the - // mul here - bool mulOverflowed; + + // Compute the multiplication and track if an overflow occurred to enable + // emitting a warning + bool mulOverflowed = false; auto intMulFun = [&resultElementWidth, &mulOverflowed]( const APInt &first, const APInt &second) { bool didOverflow;