diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
index a6d6da13cd2c5..b8983dae1dd9c 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
@@ -37,6 +37,8 @@ void populateTosaFoldConstantClampPatterns(MLIRContext *ctx,
 void populateTosaFoldConstantCastPatterns(MLIRContext *ctx,
                                           RewritePatternSet &patterns,
                                           bool enableIntCastFolding);
+void populateTosaFoldConstantMulPatterns(MLIRContext *ctx,
+                                         RewritePatternSet &patterns);
 void populateTosaFoldConstantPowPatterns(MLIRContext *ctx,
                                          RewritePatternSet &patterns);
 void populateTosaFoldConstantReciprocalPatterns(MLIRContext *ctx,
diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
index 1b6a3530d6a14..cbfec4a9890d1 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_dialect_library(MLIRTosaTransforms
   TosaFoldConstantAdd.cpp
   TosaFoldConstantCast.cpp
   TosaFoldConstantClamp.cpp
+  TosaFoldConstantMul.cpp
   TosaFoldConstantPow.cpp
   TosaFoldConstantReciprocal.cpp
   TosaFoldConstantRSQRT.cpp
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantMul.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantMul.cpp
new file mode 100644
index 0000000000000..58b1b1d7a5cf5
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantMul.cpp
@@ -0,0 +1,118 @@
+//===- TosaFoldConstantMul.cpp --------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Fold TOSA Mul operation on constant data
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/Pass/Pass.h"
+#include <llvm/ADT/APFloat.h>
+#include <llvm/ADT/APInt.h>
+
+using namespace mlir;
+using namespace mlir::tosa;
+
+namespace {
+
+struct TosaFoldConstantMul : public OpRewritePattern<MulOp> {
+
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(MulOp mulOp,
+                                PatternRewriter &rewriter) const override {
+    if (mulOp.getShift() > 0) {
+      return rewriter.notifyMatchFailure(
+          mulOp, "Non-zero shift folding is currently not implemented.");
+    }
+
+    auto leftOp = mulOp.getInput1();
+    auto rightOp = mulOp.getInput2();
+
+    // Check if both tensors are constant
+    auto lhsIsConstantCheck =
+        notifyIfNoTosaDenseConstantTensor(leftOp, mulOp, rewriter);
+    if (failed(lhsIsConstantCheck)) {
+      return lhsIsConstantCheck;
+    }
+    auto rhsIsConstantCheck =
+        notifyIfNoTosaDenseConstantTensor(rightOp, mulOp, rewriter);
+    if (failed(rhsIsConstantCheck)) {
+      return rhsIsConstantCheck;
+    }
+
+    // Extract the tensor values
+    DenseElementsAttr lhsValues;
+    matchPattern(leftOp, m_Constant(&lhsValues));
+
+    DenseElementsAttr rhsValues;
+    matchPattern(rightOp, m_Constant(&rhsValues));
+
+    if (!constantBinaryOpShouldBeFolded(mulOp, lhsValues, rhsValues)) {
+      return rewriter.notifyMatchFailure(
+          mulOp, "Currently, muls will only be folded if this requires little "
+                 "additional memory usage.");
+    }
+
+    DenseElementsAttr newTensor;
+
+    auto lhsElemType = leftOp.getType().getElementType();
+    auto rhsElemType = rightOp.getType().getElementType();
+    assert(lhsElemType == rhsElemType);
+
+    auto resultType = mulOp.getType();
+    auto resultElementType = resultType.getElementType();
+    if (isa<IntegerType>(lhsElemType)) {
+      assert(isa<IntegerType>(rhsElemType) &&
+             isa<IntegerType>(resultElementType));
+      auto resultElementWidth = resultElementType.getIntOrFloatBitWidth();
+      assert(resultElementWidth >= lhsElemType.getIntOrFloatBitWidth() &&
+             "The multiplication is expected to have an output type at least "
+             "as wide as its input types");
+
+      // Compute the multiplication and track if an overflow occurred to enable
+      // emitting a warning
+      bool mulOverflowed = false;
+      auto intMulFun = [&resultElementWidth, &mulOverflowed](
+                           const APInt &first, const APInt &second) {
+        bool didOverflow;
+        auto res = first.sext(resultElementWidth)
+                       .smul_ov(second.sext(resultElementWidth), didOverflow);
+        mulOverflowed |= didOverflow;
+        return res;
+      };
+      newTensor = applyElementWise(lhsValues, rhsValues,
+                                   resultType, intMulFun);
+      if (mulOverflowed) {
+        mulOp.emitWarning(
+            "Multiplication did overflow. The results are unspecified.");
+      }
+    } else {
+      assert(isa<FloatType>(lhsElemType) && isa<FloatType>(rhsElemType) &&
+             isa<FloatType>(resultType.getElementType()));
+      auto mulFun = [](const APFloat &first, const APFloat &second) {
+        return first * second;
+      };
+      newTensor = applyElementWise(lhsValues, rhsValues,
+                                   resultType, mulFun);
+    }
+    rewriter.replaceOpWithNewOp<ConstOp>(mulOp, newTensor.getType(), newTensor);
+
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::tosa::populateTosaFoldConstantMulPatterns(
+    MLIRContext *ctx, RewritePatternSet &patterns) {
+  patterns.add<TosaFoldConstantMul>(ctx);
+}
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp
index 1f230b6b7d1e5..0cee7680453b6 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp
@@ -54,6 +54,7 @@ struct TosaLayerwiseConstantFoldPass
     mlir::tosa::populateTosaFoldConstantCastPatterns(ctx, patterns,
                                                      enableIntCastFolding);
     mlir::tosa::populateTosaFoldConstantClampPatterns(ctx, patterns);
+    mlir::tosa::populateTosaFoldConstantMulPatterns(ctx, patterns);
     mlir::tosa::populateTosaFoldConstantPowPatterns(ctx, patterns);
     mlir::tosa::populateTosaFoldConstantReciprocalPatterns(ctx, patterns);
     mlir::tosa::populateTosaFoldConstantRSQRTPatterns(ctx, patterns);
diff --git a/mlir/test/Dialect/Tosa/constant-mul-opt.mlir b/mlir/test/Dialect/Tosa/constant-mul-opt.mlir
new file mode 100644
index 0000000000000..8bd94a9266a73
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/constant-mul-opt.mlir
@@ -0,0 +1,185 @@
+// RUN: mlir-opt --split-input-file -verify-diagnostics --tosa-layerwise-constant-fold %s | FileCheck %s
+
+// Float multiplications
+
+// CHECK-LABEL: @mul_fold_float
+func.func @mul_fold_float() -> tensor<4xf16> {
+  // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}2.32{{.*}}e+03, -1.49{{.*}}e+01, -0.{{0*}}e+00, -0.{{0*}}e+00
+  // CHECK-NOT: tosa.mul
+  // CHECK: return [[RES]]
+  %0 = "tosa.const"() {value =
+                        dense<[-17.4978, 4.9882, 0.0, -0.0]> :
+                        tensor<4xf16>
+                      } : () -> tensor<4xf16>
+  %1 = "tosa.const"() {value =
+                        dense<[-132.7, -3.0, -0.0, 5.0]> :
+                        tensor<4xf16>
+                      } : () -> tensor<4xf16>
+  %2 = "tosa.mul"(%0, %1) {shift = 0 : i32} : (tensor<4xf16>, tensor<4xf16>) -> tensor<4xf16>
+  return %2 : tensor<4xf16>
+}
+
+// CHECK-LABEL: @mul_fold_float_infinity_nan
+func.func @mul_fold_float_infinity_nan() -> tensor<7xf32> {
+  // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}0x7F800000, 0x7F800000, 0xFF800000, 0xFF800000, 0x7FC00000, 0xFF800000, 0x7FC00000
+  // CHECK-NOT: tosa.mul
+  // CHECK: return [[RES]]
+  %0 = "tosa.const"() {value =
+                        dense<[0x7F800000, 0xFF800000, 0x7F800000, 0xFF800000, 0x7FC00000, 0x7F800000, 0xFF800000]> :
+                        tensor<7xf32>
+                      } : () -> tensor<7xf32>
+  %1 = "tosa.const"() {value =
+                        dense<[3.0, -3.0, -3.0, 3.0, 1.0, 0xFF800000, 0.0]> :
+                        tensor<7xf32>
+                      } : () -> tensor<7xf32>
+  %2 = "tosa.mul"(%0, %1) {shift = 0 : i32} : (tensor<7xf32>, tensor<7xf32>) -> tensor<7xf32>
+  return %2 : tensor<7xf32>
+}
+
+// CHECK-LABEL: @mul_fold_float_overflow
+func.func @mul_fold_float_overflow() -> tensor<2xf32> {
+  // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}0x7F800000, 0xFF800000
+  // CHECK-NOT: tosa.mul
+  // CHECK: return [[RES]]
+  %0 = "tosa.const"() {value =
+                        dense<[3.1e+38, -3.1e+38]> :
+                        tensor<2xf32>
+                      } : () -> tensor<2xf32>
+  %1 = "tosa.const"() {value =
+                        dense<[2.1e+38, 1.1e+38]> :
+                        tensor<2xf32>
+                      } : () -> tensor<2xf32>
+  %2 = "tosa.mul"(%0, %1) {shift = 0 : i32} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
+  return %2 : tensor<2xf32>
+}
+
+// -----
+// Int multiplications
+
+// CHECK-LABEL: @mul_fold_int
+func.func @mul_fold_int() -> tensor<4xi32> {
+  // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}2244, -12, 0, 0
+  // CHECK-NOT: tosa.mul
+  // CHECK: return [[RES]]
+  %0 = "tosa.const"() {value =
+                        dense<[-17, 4, 0, 0]> :
+                        tensor<4xi32>
+                      } : () -> tensor<4xi32>
+  %1 = "tosa.const"() {value =
+                        dense<[-132, -3, 0, 5]> :
+                        tensor<4xi32>
+                      } : () -> tensor<4xi32>
+  %2 = "tosa.mul"(%0, %1) {shift = 0 : i32} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
+  return %2 : tensor<4xi32>
+}
+
+// CHECK-LABEL: @mul_fold_i8
+func.func @mul_fold_i8() -> tensor<4xi32> {
+  // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}204, -12, 0, 0
+  // CHECK-NOT: tosa.mul
+  // CHECK: return [[RES]]
+  %0 = "tosa.const"() {value =
+                        dense<[-17, 4, -2, 0]> :
+                        tensor<4xi8>
+                      } : () -> tensor<4xi8>
+  %1 = "tosa.const"() {value =
+                        dense<[-12, -3, 0, 5]> :
+                        tensor<4xi8>
+                      } : () -> tensor<4xi8>
+  %2 = "tosa.mul"(%0, %1) {shift = 0 : i32} : (tensor<4xi8>, tensor<4xi8>) -> tensor<4xi32>
+  return %2 : tensor<4xi32>
+}
+
+// CHECK-LABEL: @mul_fold_int_overflow
+func.func @mul_fold_int_overflow() -> tensor<4xi32> {
+  // Don't expect any specific results for the overflowing multiplication, just
+  // that it is folded.
+  // CHECK: [[RES:]] ={{.*}}tosa.const
+  // CHECK-NOT: tosa.mul
+  // CHECK: return [[RES]]
+  %0 = "tosa.const"() {value =
+                        dense<[2147483647, 2147483640, -2147483648, -2147483640]> :
+                        tensor<4xi32>
+                      } : () -> tensor<4xi32>
+  %1 = "tosa.const"() {value =
+                        dense<[1, 10, 1, 30]> :
+                        tensor<4xi32>
+                      } : () -> tensor<4xi32>
+  // expected-warning@below {{Multiplication did overflow. The results are unspecified.}}
+  %2 = "tosa.mul"(%0, %1) {shift = 0 : i32} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
+  return %2 : tensor<4xi32>
+}
+
+// -----
+// self-multiplication
+
+// CHECK-LABEL: @mul_fold_equal_args
+func.func @mul_fold_equal_args() -> tensor<3xi32> {
+  // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}289, 16, 0
+  // CHECK-NOT: tosa.mul
+  // CHECK: return [[RES]]
+  %0 = "tosa.const"() {value =
+                        dense<[-17, 4, 0]> :
+                        tensor<3xi32>
+                      } : () -> tensor<3xi32>
+  %2 = "tosa.mul"(%0, %0) {shift = 0 : i32} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
+  return %2 : tensor<3xi32>
+}
+
+// -----
+// Broadcasted multiplications
+
+// CHECK-LABEL: @mul_fold_int_broadcast_simple
+func.func @mul_fold_int_broadcast_simple() -> tensor<3xi32> {
+  // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}204, -48, 0
+  // CHECK-NOT: tosa.mul
+  // CHECK: return [[RES]]
+  %0 = "tosa.const"() {value =
+                        dense<[-17, 4, 0]> :
+                        tensor<3xi32>
+                      } : () -> tensor<3xi32>
+  %1 = "tosa.const"() {value =
+                        dense<-12> :
+                        tensor<1xi32>
+                      } : () -> tensor<1xi32>
+  %2 = "tosa.mul"(%0, %1) {shift = 0 : i32} : (tensor<3xi32>, tensor<1xi32>) -> tensor<3xi32>
+  return %2 : tensor<3xi32>
+}
+
+// CHECK-LABEL: @mul_fold_int_broadcast_complex
+func.func @mul_fold_int_broadcast_complex() -> tensor<3x3xi32> {
+  // CHECK: [[RES:]] ={{.*}}tosa.const
+  // CHECK-SAME{LITERAL}: [[204, -119, -68],
+  // CHECK-SAME{LITERAL}: [-12, 7, 4],
+  // CHECK-SAME{LITERAL}: [-228, 133, 76]]
+  // CHECK-NOT: tosa.mul
+  // CHECK: return [[RES]]
+  %0 = "tosa.const"() {value =
+                        dense<[[-17], [1], [19]]> :
+                        tensor<3x1xi32>
+                      } : () -> tensor<3x1xi32>
+  %1 = "tosa.const"() {value =
+                        dense<[[-12, 7, 4]]> :
+                        tensor<1x3xi32>
+                      } : () -> tensor<1x3xi32>
+  %2 = "tosa.mul"(%0, %1) {shift = 0 : i32} : (tensor<3x1xi32>, tensor<1x3xi32>) -> tensor<3x3xi32>
+  return %2 : tensor<3x3xi32>
+}
+
+// CHECK-LABEL: @mul_fold_int_non_zero_shift
+func.func @mul_fold_int_non_zero_shift() -> tensor<4xi32> {
+  // CHECK: [[FIRST:]] ={{.*}}tosa.const
+  // CHECK-NEXT: [[SECOND:]] ={{.*}}tosa.const
+  // CHECK-NEXT: [[MUL:]] ={{.*}}tosa.mul{{.*}}[[FIRST]], [[SECOND]]
+  // CHECK-NEXT: return [[MUL]]
+  %0 = "tosa.const"() {value =
+                        dense<[-17, 4, 0, 0]> :
+                        tensor<4xi32>
+                      } : () -> tensor<4xi32>
+  %1 = "tosa.const"() {value =
+                        dense<[-132, -3, 0, 5]> :
+                        tensor<4xi32>
+                      } : () -> tensor<4xi32>
+  %2 = "tosa.mul"(%0, %1) {shift = 1 : i32} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
+  return %2 : tensor<4xi32>
+}
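
For reference, the overflow handling on the integer path comes down to LLVM's APInt widening plus smul_ov: intMulFun sign-extends both operands to the result element width, so a widening multiply such as i8 x i8 -> i32 stays exact, while a same-width i32 multiply can still trip the overflow flag and trigger the warning exercised by @mul_fold_int_overflow. A minimal standalone sketch of that behaviour (illustrative only, not part of the patch):

  #include <cassert>
  #include <llvm/ADT/APInt.h>

  int main() {
    using llvm::APInt;
    bool overflowed = false;

    // i8 x i8 widened to a 32-bit result: -17 * -12 == 204, no overflow.
    APInt lhs(8, -17, /*isSigned=*/true);
    APInt rhs(8, -12, /*isSigned=*/true);
    APInt prod = lhs.sext(32).smul_ov(rhs.sext(32), overflowed);
    assert(!overflowed && prod.getSExtValue() == 204);

    // i32 x i32 with a 32-bit result: 2147483640 * 10 does not fit and is
    // flagged, mirroring the warning emitted by TosaFoldConstantMul.cpp.
    APInt big(32, 2147483640, /*isSigned=*/true);
    APInt ten(32, 10, /*isSigned=*/true);
    (void)big.sext(32).smul_ov(ten.sext(32), overflowed);
    assert(overflowed);
    return 0;
  }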