diff --git a/mlir/cmake/modules/AddMLIR.cmake b/mlir/cmake/modules/AddMLIR.cmake index 9c7b00b660ba7..4933cafa41ed6 100644 --- a/mlir/cmake/modules/AddMLIR.cmake +++ b/mlir/cmake/modules/AddMLIR.cmake @@ -732,7 +732,8 @@ function(mlir_target_link_libraries target type) endif() if (MLIR_LINK_MLIR_DYLIB) - target_link_libraries(${target} ${type} MLIR) + # AMD: Do not link shared, as this causes linking errors + target_link_libraries(${target} ${type} ${ARGN}) else() target_link_libraries(${target} ${type} ${ARGN}) endif() diff --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake index 815f65b106d94..404002e03b51b 100644 --- a/mlir/cmake/modules/AddMLIRPython.cmake +++ b/mlir/cmake/modules/AddMLIRPython.cmake @@ -23,6 +23,11 @@ # grouping. Source groupings form a DAG. # SOURCES: List of specific source files relative to ROOT_DIR to include. # SOURCES_GLOB: List of glob patterns relative to ROOT_DIR to include. + +if (POLICY CMP0175) + cmake_policy(SET CMP0175 OLD) +endif() + function(declare_mlir_python_sources name) cmake_parse_arguments(ARG "" diff --git a/mlir/include/mlir/Analysis/AffineExprBounds.h b/mlir/include/mlir/Analysis/AffineExprBounds.h new file mode 100644 index 0000000000000..9ffd7227d973f --- /dev/null +++ b/mlir/include/mlir/Analysis/AffineExprBounds.h @@ -0,0 +1,91 @@ +//===- AffineExprBounds.h - Compute bounds of affine expressions *- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header file defines an analysis of affine expressions to compute their +// ranges (lower/upper bounds) in a given context. 
+// +//===----------------------------------------------------------------------===// +#ifndef MLIR_ANALYSIS_AFFINEEXPRBOUNDS_H +#define MLIR_ANALYSIS_AFFINEEXPRBOUNDS_H + +#include "mlir/IR/AffineExprVisitor.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/Interfaces/InferIntRangeInterface.h" + +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/Support/LogicalResult.h" + +using namespace mlir; + +/// This visitor computes the bounds of affine expressions, using as context the +/// bounds of the dimensions of the expression. +/// +/// Example: +/// Given bounds 0 <= d0 <= 99 and 0 <= d1 <= 199, we can compute the bounds +/// of the following expression: +/// lb(2 * d0 + 3 * d1) = 0 +/// ub(2 * d0 + 3 * d1) = 795 +/// +/// * The bounds given in the context are inclusive, and the bounds returned +/// are also inclusive. +/// * If bounds are not available for a dimension, std::nullopt can be used +/// instead. The bounds of an expression that involves it will be std::nullopt. +/// * Limitations: +/// - Parametric expressions (using symbols) are not supported. +/// - Unsigned FloorDiv is currently not supported. +class AffineExprBoundsVisitor + : public AffineExprVisitor { +public: + /// Initialize the context (bounds) with APInt. All bounds must have the same + /// signedness and bit width. + AffineExprBoundsVisitor(ArrayRef> constLowerBounds, + ArrayRef> constUpperBounds, + bool boundsSigned, uint64_t bitWidth, + MLIRContext *context); + + /// Initialize the context (bounds) with 64-bit signed integers. This allows + /// to directly map index-type values such as Linalg op bounds, which are + /// represented as int64_t. + AffineExprBoundsVisitor(ArrayRef> constLowerBounds, + ArrayRef> constUpperBounds, + MLIRContext *context); + + /// Get the upper bound of \p expr using the context bounds. 
+ std::optional getUpperBound(AffineExpr expr); + std::optional getIndexUpperBound(AffineExpr expr); + + /// Get the lower bound of \p expr using the context bounds. + std::optional getLowerBound(AffineExpr expr); + std::optional getIndexLowerBound(AffineExpr expr); + + // These methods are directly called by the AffineExprVisitor base class. + LogicalResult visitMulExpr(AffineBinaryOpExpr expr); + LogicalResult visitAddExpr(AffineBinaryOpExpr expr); + LogicalResult visitDimExpr(AffineDimExpr expr); + LogicalResult visitSymbolExpr(AffineSymbolExpr expr); + LogicalResult visitConstantExpr(AffineConstantExpr expr); + LogicalResult visitCeilDivExpr(AffineBinaryOpExpr expr); + LogicalResult visitFloorDivExpr(AffineBinaryOpExpr expr); + LogicalResult visitModExpr(AffineBinaryOpExpr expr); + +private: + bool boundsSigned; + uint64_t bitWidth; + void inferBinOpRange( + AffineBinaryOpExpr expr, + const std::function)> + &opInference); + + /// Bounds that have been computed for subexpressions are memoized and reused. + llvm::DenseMap lb; + llvm::DenseMap ub; +}; + +#endif // MLIR_ANALYSIS_AFFINEEXPRBOUNDS_H diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td index 244db23925ab3..4fb2d42009367 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -860,11 +860,6 @@ def LinalgStructuredInterface /// `createFlatListOfOperandDims`. SmallVector createLoopRanges(OpBuilder &b, Location loc); - /// Compute the static loop sizes necessary to vectorize the computation. - /// This is done by applying `getShapesToLoopsMap` to - /// `createFlatListOfOperandStaticDims`. 
- SmallVector computeStaticLoopSizes(); - /// Returns the value that expresses the shape of the output in terms of /// shape of the input operands where possible LogicalResult reifyResultShapes(OpBuilder &b, diff --git a/mlir/include/mlir/Dialect/PDL/IR/Builtins.h b/mlir/include/mlir/Dialect/PDL/IR/Builtins.h index 0c6cceb54b68a..5e60391e278eb 100644 --- a/mlir/include/mlir/Dialect/PDL/IR/Builtins.h +++ b/mlir/include/mlir/Dialect/PDL/IR/Builtins.h @@ -43,8 +43,9 @@ enum class UnaryOpKind { LogicalResult addEntryToDictionaryAttr(PatternRewriter &rewriter, PDLResultList &results, ArrayRef args); -Attribute addElemToArrayAttr(PatternRewriter &rewriter, Attribute attr, - Attribute element); +LogicalResult addElemToArrayAttr(PatternRewriter &rewriter, + PDLResultList &results, + ArrayRef args); LogicalResult mul(PatternRewriter &rewriter, PDLResultList &results, llvm::ArrayRef args); LogicalResult div(PatternRewriter &rewriter, PDLResultList &results, diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index 23cb3685795d4..ccdbb60a1fe9a 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -1724,6 +1724,7 @@ def Tosa_TileOp : Tosa_InferShapedTypeOp<"tile"> { }]; let hasFolder = 1; + let hasCanonicalizer = 1; let hasVerifier = 1; } @@ -1877,6 +1878,15 @@ def Tosa_CastOp: Tosa_Op<"cast", [Pure, | signed 16 to float | int16 | float | | float 32 to float 64 | float32 | float64 | | float 64 to float 32 | float64 | float32 | + + AMD extensions: + | signed to unsigned | signed | unsigned| + | unsigned to signed | unsigned| signed | + | unsigned to float | unsigned| float | + - unsigned to signed integer and signed to unsigned integer: + wrap on overflow + - unsigned to float: + uses llvm's float to int conversion with TOSA rounding mode }]; let arguments = (ins diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h 
b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h index 90fea1f68beb5..f49332eb54290 100644 --- a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h +++ b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h @@ -84,6 +84,18 @@ LogicalResult EqualizeRanks(PatternRewriter &rewriter, Location loc, LogicalResult EqualizeRanks(ImplicitLocOpBuilder &builder, Value &input1, Value &input2); +Value getTosaConstShape(ImplicitLocOpBuilder &builder, + llvm::ArrayRef shape); + +Value getTosaConstShape(PatternRewriter &rewriter, Location loc, + llvm::ArrayRef shape); + +// Get accumulator type for TOSA convolution ops +LogicalResult getConvOpsAccType(PatternRewriter &rewriter, + RankedTensorType inputTy, + RankedTensorType weightTy, + RankedTensorType outputTy, TypeAttr &accType); + namespace { // Creates a TOSA operation and performs shape inference on the individual @@ -217,7 +229,8 @@ TosaOp CreateOpAndInferShape(PatternRewriter &rewriter, Location loc, } // Apply an int32_t permutation to some input, that should be of the same -// size as perms. Perms should contain some permutation of 0 - perms.size() - 1. +// size as perms. Perms should contain some permutation of 0 - perms.size() +// - 1. template SmallVector applyTOSAPermutation(ArrayRef input, ArrayRef perms) { diff --git a/mlir/include/mlir/IR/AffineExpr.h b/mlir/include/mlir/IR/AffineExpr.h index a93e74b449cee..28d00f1299f2f 100644 --- a/mlir/include/mlir/IR/AffineExpr.h +++ b/mlir/include/mlir/IR/AffineExpr.h @@ -110,6 +110,11 @@ class AffineExpr { /// floordiv, ceildiv, and mod is only allowed w.r.t constants. bool isPureAffine() const; + /// Returns true if this expression is monotonically increasing with respect + /// to the AffineDimExprs, i.e. increasing the value of any AffineDimExpr will + /// never decrease the value of the result. + bool isMonotonicallyIncreasing() const; + /// Returns the greatest known integral divisor of this affine expression. The /// result is always positive. 
int64_t getLargestKnownDivisor() const; diff --git a/mlir/lib/Analysis/AffineExprBounds.cpp b/mlir/lib/Analysis/AffineExprBounds.cpp new file mode 100644 index 0000000000000..b71cfe4721323 --- /dev/null +++ b/mlir/lib/Analysis/AffineExprBounds.cpp @@ -0,0 +1,198 @@ +//===- AffineExprBounds.cpp - Compute bounds of affine expressions *- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements an analysis of affine expressions to compute their +// ranges (lower/upper bounds) in a given context. +// +//===----------------------------------------------------------------------===// +#include "mlir/Analysis/AffineExprBounds.h" + +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineExprVisitor.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/Interfaces/InferIntRangeInterface.h" +#include "mlir/Interfaces/Utils/InferIntRangeCommon.h" +#include "llvm/ADT/APInt.h" + +#include + +using namespace mlir; + +AffineExprBoundsVisitor::AffineExprBoundsVisitor( + ArrayRef> constLowerBounds, + ArrayRef> constUpperBounds, bool boundsSigned, + uint64_t bitWidth, MLIRContext *context) + : boundsSigned(boundsSigned), bitWidth(bitWidth) { + assert(constLowerBounds.size() == constUpperBounds.size()); + for (unsigned i = 0; i < constLowerBounds.size(); i++) { + if (constLowerBounds[i].has_value()) { + lb[getAffineDimExpr(i, context)] = constLowerBounds[i].value(); + } + if (constUpperBounds[i].has_value()) { + ub[getAffineDimExpr(i, context)] = constUpperBounds[i].value(); + } + } +} + +AffineExprBoundsVisitor::AffineExprBoundsVisitor( + ArrayRef> constLowerBounds, + ArrayRef> constUpperBounds, MLIRContext *context) + : boundsSigned(true), bitWidth(64) { + 
assert(constLowerBounds.size() == constUpperBounds.size()); + // Convert int64_ts to APInts. + for (unsigned i = 0; i < constLowerBounds.size(); i++) { + if (constLowerBounds[i].has_value()) { + lb[getAffineDimExpr(i, context)] = + APInt(64, constLowerBounds[i].value(), /*isSigned=*/true); + } + if (constUpperBounds[i].has_value()) { + ub[getAffineDimExpr(i, context)] = + APInt(64, constUpperBounds[i].value(), /*isSigned=*/true); + } + } +} + +std::optional AffineExprBoundsVisitor::getUpperBound(AffineExpr expr) { + // Use memoized bound if available. + auto i = ub.find(expr); + if (i != ub.end()) { + return i->second; + } + // Compute the bound otherwise. + if (failed(walkPostOrder(expr))) { + return std::nullopt; + } + return ub[expr]; +} + +std::optional AffineExprBoundsVisitor::getLowerBound(AffineExpr expr) { + // Use memoized bound if available. + auto i = lb.find(expr); + if (i != lb.end()) { + return i->second; + } + // Compute the bound otherwise. + if (failed(walkPostOrder(expr))) { + return std::nullopt; + } + return lb[expr]; +} + +std::optional +AffineExprBoundsVisitor::getIndexUpperBound(AffineExpr expr) { + std::optional apIntResult = getUpperBound(expr); + if (!apIntResult) + return std::nullopt; + + return apIntResult->getSExtValue(); +} + +std::optional +AffineExprBoundsVisitor::getIndexLowerBound(AffineExpr expr) { + std::optional apIntResult = getLowerBound(expr); + if (!apIntResult) + return std::nullopt; + + return apIntResult->getSExtValue(); +} + +ConstantIntRanges getRange(APInt lb, APInt ub, bool boundsSigned) { + return ConstantIntRanges::range(lb, ub, boundsSigned); +} + +/// Wrapper around the intrange::infer* functions that infers the range of +/// binary operations on two ranges. 
+void AffineExprBoundsVisitor::inferBinOpRange( + AffineBinaryOpExpr expr, + const std::function)> + &opInference) { + ConstantIntRanges lhsRange = + getRange(lb[expr.getLHS()], ub[expr.getLHS()], boundsSigned); + ConstantIntRanges rhsRange = + getRange(lb[expr.getRHS()], ub[expr.getRHS()], boundsSigned); + ConstantIntRanges result = opInference({lhsRange, rhsRange}); + + lb[expr] = (boundsSigned) ? result.smin() : result.umin(); + ub[expr] = (boundsSigned) ? result.smax() : result.umax(); +} + +// Visitor method overrides. +LogicalResult AffineExprBoundsVisitor::visitMulExpr(AffineBinaryOpExpr expr) { + inferBinOpRange(expr, [](ArrayRef ranges) { + return intrange::inferMul(ranges); + }); + return success(); +} +LogicalResult AffineExprBoundsVisitor::visitAddExpr(AffineBinaryOpExpr expr) { + inferBinOpRange(expr, [](ArrayRef ranges) { + return intrange::inferAdd(ranges); + }); + return success(); +} +LogicalResult +AffineExprBoundsVisitor::visitCeilDivExpr(AffineBinaryOpExpr expr) { + inferBinOpRange( + expr, [boundsSigned = boundsSigned](ArrayRef ranges) { + if (boundsSigned) { + return intrange::inferCeilDivS(ranges); + } + return intrange::inferCeilDivU(ranges); + }); + return success(); +} +LogicalResult +AffineExprBoundsVisitor::visitFloorDivExpr(AffineBinaryOpExpr expr) { + // There is no inferFloorDivU in the intrange library. We only offer + // computation of bounds for signed floordiv operations. + if (boundsSigned) { + inferBinOpRange(expr, [](ArrayRef ranges) { + return intrange::inferFloorDivS(ranges); + }); + return success(); + } + return failure(); +} +LogicalResult AffineExprBoundsVisitor::visitModExpr(AffineBinaryOpExpr expr) { + // Only support integers >= 1 as RHS. + auto rhsConst = dyn_cast(expr.getRHS()); + if (!rhsConst || rhsConst.getValue() < 1) + return failure(); + + inferBinOpRange(expr, [boundsSigned = + boundsSigned](ArrayRef ranges) { + // Mod must return a value between 0 and N-1. 
+ // Computing (N + (expr mod N)) mod N is guaranteed to yield a result in + // this range. + if (boundsSigned) { + auto rhs = ranges[1]; + auto lhs = ranges[0]; + return intrange::inferRemS( + {intrange::inferAdd({intrange::inferRemS({lhs, rhs}), rhs}), rhs}); + } + return intrange::inferRemU(ranges); + }); + return success(); +} +LogicalResult AffineExprBoundsVisitor::visitDimExpr(AffineDimExpr expr) { + if (lb.find(expr) == lb.end() || ub.find(expr) == ub.end()) { + return failure(); + } + return success(); +} +LogicalResult AffineExprBoundsVisitor::visitSymbolExpr(AffineSymbolExpr expr) { + return failure(); +} +LogicalResult +AffineExprBoundsVisitor::visitConstantExpr(AffineConstantExpr expr) { + APInt apIntVal = + APInt(bitWidth, static_cast(expr.getValue()), boundsSigned); + lb[expr] = apIntVal; + ub[expr] = apIntVal; + return success(); +} diff --git a/mlir/lib/Analysis/CMakeLists.txt b/mlir/lib/Analysis/CMakeLists.txt index 609cb34309829..9462471a367a0 100644 --- a/mlir/lib/Analysis/CMakeLists.txt +++ b/mlir/lib/Analysis/CMakeLists.txt @@ -21,6 +21,7 @@ set(LLVM_OPTIONAL_SOURCES add_subdirectory(Presburger) add_mlir_library(MLIRAnalysis + AffineExprBounds.cpp AliasAnalysis.cpp CallGraph.cpp DataFlowFramework.cpp diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp index d347ae916784b..3f39de4200359 100644 --- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp +++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp @@ -84,7 +84,7 @@ class CmpFOpConversion : public OpConversionPattern { matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - if (!isa(adaptor.getRhs().getType())) { + if (!emitc::isFloatOrOpaqueType(adaptor.getRhs().getType())) { return rewriter.notifyMatchFailure(op.getLoc(), "cmpf currently only supported on " "floats, not tensors/vectors thereof"); diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp 
b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp index 57710b10283c1..f2b8598681a6a 100644 --- a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp @@ -768,17 +768,25 @@ struct OrderedPredicate { /// model. bool operator<(const OrderedPredicate &rhs) const { // Sort by: + // * not being a constraint. Rationale: When writing constraints, it is + // sometimes assumed that checks for null or operation names are executed + // before the constraint. As there is no dependency between these + // operations, this is not always guaranteed, which can lead to bugs if the + // constraint is not checking inputs for null itself. By ordering + // constraints to the end, it is assured that implicit checks are run + // before them // * higher first and secondary order sums // * lower depth // * lower position dependency // * lower predicate dependency // * lower tie breaking ID auto *rhsPos = rhs.position; - return std::make_tuple(primary, secondary, rhsPos->getOperationDepth(), + return std::make_tuple(!isa(question), primary, + secondary, rhsPos->getOperationDepth(), rhsPos->getKind(), rhs.question->getKind(), rhs.id) > - std::make_tuple(rhs.primary, rhs.secondary, - position->getOperationDepth(), position->getKind(), - question->getKind(), id); + std::make_tuple(!isa(rhs.question), rhs.primary, + rhs.secondary, position->getOperationDepth(), + position->getKind(), question->getKind(), id); } }; diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp index caf9cdb3a3eb4..52f2129c77ee3 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -1094,19 +1094,6 @@ SmallVector LinalgOp::createLoopRanges(OpBuilder &b, Location loc) { return res; } -SmallVector LinalgOp::computeStaticLoopSizes() { - AffineMap map = getLoopsToShapesMap(); - unsigned numDims = map.getNumDims(), numRes = map.getNumResults(); - 
SmallVector allShapeSizes = createFlatListOfOperandStaticDims(); - SmallVector res(numDims, 0); - for (unsigned idx = 0; idx < numRes; ++idx) { - auto result = map.getResult(idx); - if (auto d = dyn_cast(result)) - res[d.getPosition()] = allShapeSizes[idx]; - } - return res; -} - /// Visitor to check if any of the given set of positions from AffineDimExprs /// are used within an AffineExpr. struct HasAffineDimExprVisitor diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp index 2e6079e1402e1..b53180b5cf7c3 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp @@ -130,7 +130,7 @@ class FoldConstantBase : public OpInterfaceRewritePattern { return failure(); } - SmallVector loopBounds = linalgOp.computeStaticLoopSizes(); + SmallVector loopBounds = linalgOp.getStaticLoopRanges(); int64_t numElements = outputType.getNumElements(); // Use APInt/APFloat instead of Attribute here for constructing the output. 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp index a838b99c9dbb3..269272c10903c 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -218,10 +218,15 @@ struct LinalgOpTilingInterface })); OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber); + SmallVector allShapeSizes = + linalgOp.createFlatListOfOperandDims(b, linalgOp.getLoc()); + SmallVector sizeBounds = + mlir::affine::makeComposedFoldedMultiResultAffineApply( + b, loc, linalgOp.getShapesToLoopsMap(), allShapeSizes); SliceParameters sliceParams = computeSliceParameters( b, loc, outOperand->get(), sizes, linalgOp.getMatchingIndexingMap(outOperand), offsets, - /*ubs*/ {}, subShapeSizes, true); + /*ubs*/ sizeBounds, subShapeSizes, true); resultOffsets = sliceParams.offsets; resultSizes = sliceParams.sizes; return success(); diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index 8e898904d87c2..dc2e3971d28bd 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -56,19 +56,24 @@ namespace { // `d0 + 2 * d1 + d3` is tiled by [0, 0, 0, 2] but not by [0, 0, 2, 0] // struct TileCheck : public AffineExprVisitor { - TileCheck(ArrayRef tileSizes, ArrayRef sizeBounds) - : tileSizes(tileSizes), sizeBounds(sizeBounds) {} + TileCheck(ArrayRef tileSizes, ArrayRef sizeBounds, + bool isMonotonicallyIncreasing) + : tileSizes(tileSizes), sizeBounds(sizeBounds), + isMonotonicallyIncreasing(isMonotonicallyIncreasing) {} void visitDimExpr(AffineDimExpr expr) { unsigned pos = expr.getPosition(); - // This dimension is tiled if the tile size is larger than zero and not - // equal to its domain size (if statically known). 
- std::optional tileSize = getConstantIntValue(tileSizes[pos]); - if (tileSize && !sizeBounds.empty()) { - std::optional sizeBound = getConstantIntValue(sizeBounds[pos]); - if (sizeBound && *sizeBound == *tileSize) { - return; + // If the expression is non monotonic, this dimension is tiled if the tile + // size is larger than zero and not equal to its domain size (if statically + // known). + if (!isMonotonicallyIncreasing) { + std::optional tileSize = getConstantIntValue(tileSizes[pos]); + if (tileSize && !sizeBounds.empty()) { + std::optional sizeBound = getConstantIntValue(sizeBounds[pos]); + if (sizeBound && *sizeBound == *tileSize) { + return; + } } } @@ -84,6 +89,7 @@ struct TileCheck : public AffineExprVisitor { bool isTiled = false; ArrayRef tileSizes; ArrayRef sizeBounds; + bool isMonotonicallyIncreasing; }; } // namespace @@ -92,7 +98,7 @@ static bool isTiled(AffineExpr expr, ArrayRef tileSizes, ArrayRef sizeBounds) { if (!expr) return false; - TileCheck t(tileSizes, sizeBounds); + TileCheck t(tileSizes, sizeBounds, expr.isMonotonicallyIncreasing()); t.visit(expr); return t.isTiled; } diff --git a/mlir/lib/Dialect/PDL/IR/Builtins.cpp b/mlir/lib/Dialect/PDL/IR/Builtins.cpp index 9e4efbf7e71c0..bc238092c141d 100644 --- a/mlir/lib/Dialect/PDL/IR/Builtins.cpp +++ b/mlir/lib/Dialect/PDL/IR/Builtins.cpp @@ -38,13 +38,19 @@ LogicalResult addEntryToDictionaryAttr(PatternRewriter &rewriter, return success(); } -mlir::Attribute addElemToArrayAttr(mlir::PatternRewriter &rewriter, - mlir::Attribute attr, - mlir::Attribute element) { - assert(isa(attr)); - auto values = cast(attr).getValue().vec(); - values.push_back(element); - return rewriter.getArrayAttr(values); +LogicalResult addElemToArrayAttr(PatternRewriter &rewriter, + PDLResultList &results, + ArrayRef args) { + + assert(args.size() == 2 && + "Expected two arguments, one ArrayAttr and one Attr"); + auto arrayAttr = cast(args[0].cast()); + auto attrElement = args[1].cast(); + llvm::SmallVector 
values(arrayAttr.getValue()); + values.push_back(attrElement); + + results.push_back(rewriter.getArrayAttr(values)); + return success(); } template @@ -344,11 +350,15 @@ void registerBuiltins(PDLPatternModule &pdlPattern) { // See Parser::defineBuiltins() pdlPattern.registerRewriteFunction( "__builtin_addEntryToDictionaryAttr_rewrite", addEntryToDictionaryAttr); - pdlPattern.registerRewriteFunction("__builtin_addElemToArrayAttr", - addElemToArrayAttr); pdlPattern.registerConstraintFunction( "__builtin_addEntryToDictionaryAttr_constraint", addEntryToDictionaryAttr); + + pdlPattern.registerRewriteFunction("__builtin_addElemToArrayAttrRewriter", + addElemToArrayAttr); + pdlPattern.registerConstraintFunction( + "__builtin_addElemToArrayAttrConstraint", addElemToArrayAttr); + pdlPattern.registerRewriteFunction("__builtin_mulRewrite", mul); pdlPattern.registerRewriteFunction("__builtin_divRewrite", div); pdlPattern.registerRewriteFunction("__builtin_modRewrite", mod); @@ -357,22 +367,14 @@ void registerBuiltins(PDLPatternModule &pdlPattern) { pdlPattern.registerRewriteFunction("__builtin_log2Rewrite", log2); pdlPattern.registerRewriteFunction("__builtin_exp2Rewrite", exp2); pdlPattern.registerRewriteFunction("__builtin_absRewrite", abs); - pdlPattern.registerConstraintFunction("__builtin_mulConstraint", - mul); - pdlPattern.registerConstraintFunction("__builtin_divConstraint", - div); - pdlPattern.registerConstraintFunction("__builtin_modConstraint", - mod); - pdlPattern.registerConstraintFunction("__builtin_addConstraint", - add); - pdlPattern.registerConstraintFunction("__builtin_subConstraint", - sub); - pdlPattern.registerConstraintFunction("__builtin_log2Constraint", - log2); - pdlPattern.registerConstraintFunction("__builtin_exp2Constraint", - exp2); - pdlPattern.registerConstraintFunction("__builtin_absConstraint", - abs); + pdlPattern.registerConstraintFunction("__builtin_mulConstraint", mul); + pdlPattern.registerConstraintFunction("__builtin_divConstraint", 
div); + pdlPattern.registerConstraintFunction("__builtin_modConstraint", mod); + pdlPattern.registerConstraintFunction("__builtin_addConstraint", add); + pdlPattern.registerConstraintFunction("__builtin_subConstraint", sub); + pdlPattern.registerConstraintFunction("__builtin_log2Constraint", log2); + pdlPattern.registerConstraintFunction("__builtin_exp2Constraint", exp2); + pdlPattern.registerConstraintFunction("__builtin_absConstraint", abs); pdlPattern.registerConstraintFunction("__builtin_equals", equals); } } // namespace mlir::pdl diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 7d21317d5e235..532237f083e89 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -60,9 +60,102 @@ struct ConcatOptimization : public OpRewritePattern { } }; +struct SelfConcatToTile : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tosa::ConcatOp concatOp, + PatternRewriter &rewriter) const override { + if (llvm::all_equal(concatOp->getUsers())) { + const auto concatUser = llvm::dyn_cast( + concatOp->getUses().begin()->getOwner()); + if (concatUser) { + // Try folding the concat into its consumer before rewriting it to a + // tile. 
+ SmallVector replacementValues; + auto foldResult = rewriter.tryFold(concatUser, replacementValues); + if (foldResult.succeeded()) { + if (!replacementValues.empty()) { + rewriter.replaceOp(concatUser, replacementValues); + } + return success(); + } + } + } + + if (!llvm::all_equal(concatOp->getOperands())) { + return rewriter.notifyMatchFailure( + concatOp, "Requires all operands to be the same"); + } + const auto concatType = dyn_cast(concatOp.getType()); + if (!concatType || !concatType.hasRank()) { + return rewriter.notifyMatchFailure(concatOp, + "Requires concat to be ranked"); + } + SmallVector multiplies(concatType.getRank(), 1); + multiplies[concatOp.getAxis()] = concatOp->getNumOperands(); + auto constantShapeValue = + getTosaConstShape(rewriter, concatOp->getLoc(), multiplies); + auto tileOp = rewriter.createOrFold( + concatOp->getLoc(), concatOp.getType(), concatOp->getOperand(0), + constantShapeValue); + rewriter.replaceOp(concatOp, {tileOp}); + return success(); + } +}; + void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.add(context); + results.add(context); +} + +struct FuseChainedTile : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tosa::TileOp op, + PatternRewriter &rewriter) const override { + SmallVector multiplies; + if (failed(op.getConstantMultiples(multiplies))) { + return rewriter.notifyMatchFailure(op, "Requires const multiplies"); + } + auto inputTile = op.getInput1().getDefiningOp(); + if (!inputTile) { + return rewriter.notifyMatchFailure(op, "Input is not a TileOp"); + } + if (!inputTile->hasOneUse()) { + return rewriter.notifyMatchFailure(op, + "Input tile should only have one use"); + } + SmallVector inputTileMultiples; + if (failed(inputTile.getConstantMultiples(inputTileMultiples))) { + return rewriter.notifyMatchFailure( + op, "Requires const multiplies on input tile"); + ; + } + + for (auto [idx, multiplier] : 
llvm::enumerate(inputTileMultiples)) { + multiplies[idx] *= multiplier; + } + auto constantShapeValue = getTosaConstShape( + rewriter, + rewriter.getFusedLoc( + {op.getMultiples().getLoc(), inputTile.getMultiples().getLoc()}), + multiplies); + + rewriter.modifyOpInPlace(op, [&]() { + op.setOperand(0, inputTile->getOperand(0)); + op.setOperand(1, constantShapeValue); + op.getOperation()->setLoc( + FusedLoc::get(getContext(), {inputTile->getLoc(), op.getLoc()})); + }); + + return success(); + } +}; + +void TileOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(context); } struct SqrtReciprocalOptimization : public OpRewritePattern { @@ -73,44 +166,54 @@ struct SqrtReciprocalOptimization : public OpRewritePattern { PatternRewriter &rewriter) const override { // Check that the PowOp has a single user if (!op->hasOneUse()) - return rewriter.notifyMatchFailure(op, "pow operator has more than one user"); + return rewriter.notifyMatchFailure(op, + "pow operator has more than one user"); - Operation* user = *op->user_begin(); + Operation *user = *op->user_begin(); // Check that this user is a reciprocal if (!isa(user)) - return rewriter.notifyMatchFailure(op, "expected a pow + reciprocal pattern"); + return rewriter.notifyMatchFailure(op, + "expected a pow + reciprocal pattern"); - // Check that the Pow op is an Sqrt - its second input should be the scale, 0.5 for Sqrt. - Operation* powScale = op.getInput2().getDefiningOp(); + // Check that the Pow op is an Sqrt - its second input should be the scale, + // 0.5 for Sqrt. 
+ Operation *powScale = op.getInput2().getDefiningOp(); if (!powScale || !isa(powScale)) - return rewriter.notifyMatchFailure(op, "expected the pow to have a constant scale input"); + return rewriter.notifyMatchFailure( + op, "expected the pow to have a constant scale input"); - auto scale = cast(cast(powScale).getValue()); + auto scale = + cast(cast(powScale).getValue()); if (!scale.isSplat()) - return rewriter.notifyMatchFailure(op, "expected the pow scale to be a splat tensor"); + return rewriter.notifyMatchFailure( + op, "expected the pow scale to be a splat tensor"); float scaleValue = scale.getSplatValue().convertToFloat(); - if(scaleValue != 0.5) - return rewriter.notifyMatchFailure(op, "expected the pow to have a scale of 0.5 to be a sqrt"); + if (scaleValue != 0.5) + return rewriter.notifyMatchFailure( + op, "expected the pow to have a scale of 0.5 to be a sqrt"); auto inputType = cast(op.getOperand(0).getType()); auto outputType = cast(op.getType()); // If the operator needs tiling, fail to match - // An improvement for the future would be to generate a tile operator here instead + // An improvement for the future would be to generate a tile operator here + // instead if (inputType != outputType) - return rewriter.notifyMatchFailure(op, "input type and output type are different, tiling is not supported for this canonicalization"); + return rewriter.notifyMatchFailure( + op, "input type and output type are different, tiling is not " + "supported for this canonicalization"); auto rsqrtOp = rewriter.create( rewriter.getFusedLoc({op.getLoc(), user->getLoc()}), outputType, op.getInput1()); rewriter.replaceOp(user, rsqrtOp); - + return success(); } }; void PowOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { + MLIRContext *context) { results.add(context); } @@ -611,35 +714,121 @@ struct ConcatSliceOptimization : public OpRewritePattern { llvm::SmallVector sliceStart(sliceOp.getStart()); llvm::ArrayRef sliceSize = 
sliceOp.getSize(); - - // Validate slice on the concatenated axis. Slicing along this - // axis should span only one of the inputs to the concatenate - // operation. - std::optional replaceWithSlice; + llvm::SmallVector requiredConcatInputs; + int64_t processedOriginalConcatInputSize = 0; + int64_t droppedConcatInputSize = 0; for (auto input : inputs) { - auto inputType = dyn_cast(input.getType()); + const auto inputType = dyn_cast(input.getType()); if (!inputType || !inputType.hasStaticShape()) return rewriter.notifyMatchFailure( sliceOp, "concat input must be a static ranked tensor"); + if (processedOriginalConcatInputSize < + (sliceStart[axis] + sliceSize[axis]) && + (processedOriginalConcatInputSize + inputType.getDimSize(axis)) > + sliceStart[axis]) { + if (requiredConcatInputs.empty()) { + droppedConcatInputSize = processedOriginalConcatInputSize; + } + requiredConcatInputs.push_back(input); + } + processedOriginalConcatInputSize += inputType.getDimSize(axis); + } + if (requiredConcatInputs.size() == concatOp->getNumOperands()) { + return rewriter.notifyMatchFailure( + sliceOp, "Could not reduce number of inputs to preceding concat"); + } + if (requiredConcatInputs.size() != 1 && !concatOp->hasOneUse()) { + return rewriter.notifyMatchFailure( + sliceOp, + "Preceding concat must have a single use"); // Do not introduce new + // concats + } + if (requiredConcatInputs.empty()) { + return rewriter.notifyMatchFailure( + sliceOp, "degenerate slice with zero sized dim in output"); + } + sliceStart[axis] -= droppedConcatInputSize; + auto newConcat = rewriter.create( + concatOp->getLoc(), requiredConcatInputs, axis); + auto newSlice = rewriter.create( + sliceOp->getLoc(), sliceOp.getType(), newConcat, + rewriter.getDenseI64ArrayAttr(sliceStart), + rewriter.getDenseI64ArrayAttr(sliceSize)); + rewriter.replaceOp(sliceOp, newSlice); + return success(); + } +}; - if (sliceStart[axis] >= 0 && - (sliceStart[axis] + sliceSize[axis]) <= inputType.getDimSize(axis)) { - 
replaceWithSlice = rewriter - .create( - sliceOp.getLoc(), sliceOp.getType(), input, - rewriter.getDenseI64ArrayAttr(sliceStart), - rewriter.getDenseI64ArrayAttr(sliceSize)) - .getResult(); - break; +/// This pattern adjusts the multipliers of a tile followed by a slice to only +/// tile as much data as is required by the slice +struct TileSliceOptimization : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tosa::SliceOp sliceOp, + PatternRewriter &rewriter) const override { + Value sliceInput = sliceOp.getInput1(); + auto tileOp = sliceInput.getDefiningOp(); + if (!tileOp) + return rewriter.notifyMatchFailure(sliceOp, + "slice input must be tile operation"); + if (!tileOp->hasOneUse()) + return rewriter.notifyMatchFailure( + sliceOp, "preceding tile must have a single use"); // Do not insert + // additional tiles + + const auto tileOpInputType = + dyn_cast(tileOp->getOperand(0).getType()); + if (!tileOpInputType || !tileOpInputType.hasStaticShape()) + return rewriter.notifyMatchFailure( + sliceOp, "input to preceding tile op must be a static ranked tensor"); + llvm::SmallVector requiredMultipliers; + llvm::SmallVector newTileStarts; + requiredMultipliers.reserve(tileOpInputType.getRank()); + newTileStarts.reserve(tileOpInputType.getRank()); + SmallVector tileMultiplies; + const LogicalResult tileHasConstantMultiplies = + tileOp.getConstantMultiples(tileMultiplies); + for (auto [axis, sliceStart, sliceSize] : + llvm::enumerate(sliceOp.getStart(), sliceOp.getSize())) { + if (sliceSize <= 0) { + return rewriter.notifyMatchFailure( + sliceOp, "degenerate slice with zero sized dim"); } - sliceStart[axis] -= inputType.getDimSize(axis); + const int64_t tileInputDimSize = tileOpInputType.getDimSize(axis); + const int64_t sliceOffsetInNewFirstTile = sliceStart % tileInputDimSize; + const int64_t sliceSizeInFirstTile = + std::min(tileInputDimSize - sliceOffsetInNewFirstTile, sliceSize); + assert(sliceSizeInFirstTile 
> 0); + const int64_t requiredMultiplierWithoutFirstTile = + llvm::divideCeil(sliceSize - sliceSizeInFirstTile, tileInputDimSize); + const int64_t requiredMultiplier = + requiredMultiplierWithoutFirstTile + (sliceSizeInFirstTile != 0); + assert(failed(tileHasConstantMultiplies) || + requiredMultiplier <= tileMultiplies[axis]); + requiredMultipliers.push_back(requiredMultiplier); + newTileStarts.push_back(sliceOffsetInNewFirstTile); } - if (!replaceWithSlice) + if (succeeded(tileHasConstantMultiplies) && + requiredMultipliers == tileMultiplies) { return rewriter.notifyMatchFailure( - sliceOp, "corresponding concat input not found for slice"); + sliceOp, "could not reduce multipliers in preceding tile"); + } - rewriter.replaceOp(sliceOp, replaceWithSlice.value()); + llvm::SmallVector newTileShape(tileOpInputType.getShape()); + for (auto [newShape, multiplier] : + llvm::zip_equal(newTileShape, requiredMultipliers)) { + newShape *= multiplier; + } + auto constantShapeValue = getTosaConstShape( + rewriter, tileOp.getMultiples().getLoc(), requiredMultipliers); + auto newTile = rewriter.create( + tileOp->getLoc(), tileOpInputType.clone(newTileShape), + tileOp->getOperand(0), constantShapeValue); + auto newSlice = rewriter.create( + sliceOp->getLoc(), sliceOp.getType(), newTile, + rewriter.getDenseI64ArrayAttr(newTileStarts), sliceOp.getSizeAttr()); + rewriter.replaceOp(sliceOp, newSlice); return success(); } }; @@ -647,6 +836,7 @@ struct ConcatSliceOptimization : public OpRewritePattern { void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.add(context); + results.add(context); } struct MinToClampOptimization : public OpRewritePattern { @@ -1053,7 +1243,8 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) { // cast-to-iN(cast-to-iM(x)) -> cast-to-iN(x) when N <= M if (auto cast = getInput().getDefiningOp()) { - auto intermediateElTy = cast.getType().getElementType().dyn_cast(); + auto intermediateElTy = + 
cast.getType().getElementType().dyn_cast(); auto finalElTy = getType().getElementType().dyn_cast(); if (intermediateElTy && finalElTy && intermediateElTy.getSignedness() == finalElTy.getSignedness() && @@ -1272,6 +1463,30 @@ OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) { } OpFoldResult SliceOp::fold(FoldAdaptor adaptor) { + const auto tryFoldWithPrecedingSlice = [this](FoldAdaptor adaptor) { + auto precedingSliceOp = getInput1().getDefiningOp(); + if (!precedingSliceOp) + return failure(); + const auto precedingSliceStart = precedingSliceOp.getStart(); + const auto thisSliceStart = getStart(); + SmallVector newSliceStart; + newSliceStart.reserve(precedingSliceStart.size()); + for (auto [startPreceding, startThis] : + llvm::zip_equal(precedingSliceStart, thisSliceStart)) { + newSliceStart.push_back(startPreceding + startThis); + } + setOperand(precedingSliceOp->getOperand(0)); + setStart(newSliceStart); + getOperation()->setLoc( + FusedLoc::get(getContext(), {precedingSliceOp->getLoc(), getLoc()})); + return success(); + }; + + // First try folding the preceding slice, this also works if the shapes are + // dynamic + if (succeeded(tryFoldWithPrecedingSlice(adaptor))) + return getResult(); + auto inputTy = llvm::dyn_cast(getInput1().getType()); auto outputTy = llvm::dyn_cast(getType()); @@ -1320,21 +1535,13 @@ OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) { } OpFoldResult TileOp::fold(FoldAdaptor adaptor) { - if (getInput1().getType() == getType()) { - if (auto multiples = llvm::dyn_cast_if_present( - adaptor.getMultiples())) { - if (multiples.isSplat() && - multiples.getSplatValue().getSExtValue() == 1) - return getInput1(); - if (auto int_array_attr = - llvm::dyn_cast(multiples)) { - if (llvm::all_of(int_array_attr.getValues(), - [](APInt v) { return v.getSExtValue() == 1; })) - return getInput1(); - } - } + SmallVector multiples; + if (getInput1().getType() != getType() || + failed(getConstantMultiples(multiples)) || + !llvm::all_of(multiples, 
[](int64_t v) { return v == 1; })) { + return {}; } - return {}; + return getInput1(); } OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) { @@ -1402,7 +1609,7 @@ OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) { OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) { /// Remove operands that have zero elements. bool changed = false; - for (size_t i = 0; i < getInput1().size(); ) { + for (size_t i = 0; i < getInput1().size();) { auto input = cast(getInput1()[i].getType()); // Ensure that we have at least one operand left. if (input.getDimSize(getAxis()) == 0 && getInput1().size() > 1) { diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp index ea6f295ff2feb..4932ce87d57b7 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp @@ -110,16 +110,34 @@ DenseElementsAttr applyElementWise( // We already know the amount of values we will insert, reserve space for // all of them to avoid dynamic resizing transformedValues.reserve(toTransform.getNumElements()); - for (auto val : toTransform.getValues()) { - auto transformedVal = toApply(val, targetType); - transformedValues.push_back(transformedVal); + if constexpr (std::is_same_v) { + for (auto val : toTransform.getValues()) { + auto transformedVal = + toApply(APSInt(val, toTransform.getElementType().isUnsignedInteger()), + targetType); + transformedValues.push_back(transformedVal); + } + } else { + for (auto val : toTransform.getValues()) { + auto transformedVal = toApply(val, targetType); + transformedValues.push_back(transformedVal); + } } // Make sure that the output tensor has the expected output type auto inShape = toTransform.getType(); auto outTy = inShape.cloneWith({}, targetType); - return DenseElementsAttr::get(outTy, transformedValues); + if constexpr (std::is_same_v) { + SmallVector transformedValuesAPInt; + transformedValuesAPInt.reserve(transformedValues.size()); + for (APSInt val : 
transformedValues) { + transformedValuesAPInt.emplace_back(val); + } + return DenseElementsAttr::get(outTy, transformedValuesAPInt); + } else { + return DenseElementsAttr::get(outTy, transformedValues); + } } template DenseElementsAttr applyElementWise( @@ -881,10 +899,10 @@ struct TosaFoldConstantCast : public TosaFoldConstantBase { using TosaFoldConstantBase::TosaFoldConstantBase; - static APFloat convertIntToFloat(const APInt &toConvert, + static APFloat convertIntToFloat(const APSInt &toConvert, FloatType targetType) { APFloat res(targetType.getFloatSemantics()); - res.convertFromAPInt(toConvert, true /* isSigned */, tosaRoundingMode); + res.convertFromAPInt(toConvert, toConvert.isSigned(), tosaRoundingMode); return res; } @@ -928,15 +946,14 @@ struct TosaFoldConstantCast : public TosaFoldConstantBase { return converted; } - static APInt convertIntToInt(const APInt &toConvert, IntegerType targetType) { + static APSInt convertIntToInt(const APSInt &toConvert, + IntegerType targetType) { // Make sure to properly translate booleans if (targetType.getWidth() == 1) { - return toConvert.isZero() ? APInt::getZero(1) : APInt::getAllOnes(1); - } - if (targetType.isUnsigned()) { - return toConvert.zextOrTrunc(targetType.getIntOrFloatBitWidth()); + return APSInt(toConvert.isZero() ? APInt::getZero(1) + : APInt::getAllOnes(1)); } - return toConvert.sextOrTrunc(targetType.getIntOrFloatBitWidth()); + return toConvert.extOrTrunc(targetType.getIntOrFloatBitWidth()); } static void warnAboutNaNToIntCast(DenseElementsAttr elements, CastOp location, @@ -981,11 +998,11 @@ struct TosaFoldConstantCast : public TosaFoldConstantBase { warnAboutNaNToIntCast(elements, tosaCast, rewriter); // Only fold splat tensors and those used only once to avoid duplicating - // them. + // them and increasing memory consumption. 
if (!inputTensor.hasOneUse() && !isa(elements)) { - return rewriter.notifyMatchFailure(tosaCast, - "Currently, casts will only be folded " - "if its input only has a single user"); + return rewriter.notifyMatchFailure( + tosaCast, "Currently, casts will only be folded " + "if its input only has a single user or is a splat value."); } // Report a match failure for unexpected types @@ -994,20 +1011,17 @@ struct TosaFoldConstantCast : public TosaFoldConstantBase { tosaCast, "Only casts from/to int/float are supported."); } - auto isUnsigned = [](Type toCheck) { - return isa(toCheck) && - cast(toCheck).isUnsigned(); - }; - auto typesToCheck = {toType, fromType}; - if (llvm::any_of(typesToCheck, isUnsigned)) { + // TOSA spec does not allow casts from/to unsigned, but we partially do, to + // enable the folding of lowered qdq nodes + if (isa(fromType) && isa(toType) && + cast(toType).isUnsigned()) { // TOSA casts currently don't support unsigned integers. - // To support them by here, one could use APSInt instead of APInts, - // however, this causes trouble with `getValues` which does not support - // APSInts currently. 
+ // Casting float to unsigned int would need a decision about how to handle + // negative floats return rewriter.notifyMatchFailure( - tosaCast, "Cast folding from/to unsigned integers is not supported."); + tosaCast, + "Cast folding from float to unsigned integers is not supported."); } - DenseElementsAttr res; if (auto intOutTy = dyn_cast(toType)) { if (isa(fromType)) { @@ -1015,7 +1029,7 @@ struct TosaFoldConstantCast : public TosaFoldConstantBase { elements, &convertFloatToInt, intOutTy); } else { assert(isa(fromType)); - res = applyElementWise( + res = applyElementWise( elements, &convertIntToInt, intOutTy); } } else { @@ -1026,7 +1040,7 @@ struct TosaFoldConstantCast : public TosaFoldConstantBase { elements, &convertFloatToFloat, floatOutTy); } else { assert(isa(fromType)); - res = applyElementWise( + res = applyElementWise( elements, &convertIntToFloat, floatOutTy); } } diff --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp index 1f6e3b2ab8391..ad2363d5c4140 100644 --- a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp +++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" +#include "mlir/Dialect/Quant/IR/QuantTypes.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" using namespace mlir; @@ -160,3 +161,73 @@ LogicalResult mlir::tosa::EqualizeRanks(ImplicitLocOpBuilder &builder, return success(); } + +namespace { +SmallVector convertFromMlirShape(ArrayRef shape) { + return to_vector(llvm::map_range(shape, [](int64_t dim) { + return ShapedType::isDynamic(dim) ? 
-1 : dim; + })); +} +} // namespace + +Value mlir::tosa::getTosaConstShape(ImplicitLocOpBuilder &builder, + llvm::ArrayRef shape) { + auto attr = builder.getIndexTensorAttr(convertFromMlirShape(shape)); + auto type = mlir::tosa::shapeType::get(builder.getContext(), shape.size()); + mlir::Operation *mlir_op = builder.create(type, attr); + return mlir_op->getResult(0); +} + +Value mlir::tosa::getTosaConstShape(PatternRewriter &rewriter, Location loc, + llvm::ArrayRef shape) { + ImplicitLocOpBuilder builder(loc, rewriter); + return getTosaConstShape(builder, shape); +} + +// AMD: Picked from torch-mlir 12250739bfe85b702f9503cad45c2e535ea8eb18 +// Get accumulator type for TOSA convolution ops +LogicalResult mlir::tosa ::getConvOpsAccType(PatternRewriter &rewriter, + RankedTensorType inputTy, + RankedTensorType weightTy, + RankedTensorType outputTy, + TypeAttr &accType) { + auto inputElemTy = inputTy.getElementType(); + auto weightElemTy = weightTy.getElementType(); + auto outputElemTy = outputTy.getElementType(); + + auto quantTy = dyn_cast(inputElemTy); + if (quantTy) + inputElemTy = quantTy.getStorageType(); + + // Get TOSA conv ops acc type based on input, weight, and output types + // according to the spec: + // https://www.mlplatform.org/tosa/tosa_spec.html#_conv2d + // https://www.mlplatform.org/tosa/tosa_spec.html#_depthwise_conv2d + // https://www.mlplatform.org/tosa/tosa_spec.html#_conv3d + // + // For undefined dtypes in TOSA like I64 and F64, acc_type will be set to the + // output type but does not offer any guarantee on the numerical precision + // since such cases will fail TOSA validation. 
+ if ((inputElemTy.isF32() && weightElemTy.isF32() && outputElemTy.isF32()) || + (inputElemTy.isF16() && weightElemTy.isF16() && outputElemTy.isF16()) || + (inputElemTy.isBF16() && weightElemTy.isBF16() && + outputElemTy.isBF16())) { + accType = mlir::TypeAttr::get(rewriter.getF32Type()); + } else if (inputElemTy.isInteger(8) && + (weightElemTy.isInteger(8) || weightElemTy.isInteger(4)) && + outputElemTy.isInteger(32)) { + accType = mlir::TypeAttr::get(rewriter.getIntegerType(32)); + } else if (inputElemTy.isInteger(16) && weightElemTy.isInteger(8) && + outputElemTy.isInteger(48)) { + accType = mlir::TypeAttr::get(rewriter.getIntegerType(48)); + } else if ((inputElemTy.isFloat8E4M3() && weightElemTy.isFloat8E4M3() && + outputElemTy.isF16()) || + (inputElemTy.isFloat8E5M2() && weightElemTy.isFloat8E5M2() && + outputElemTy.isF16())) { + accType = mlir::TypeAttr::get(rewriter.getF16Type()); + } else { + accType = mlir::TypeAttr::get(outputElemTy); + } + + return success(); +} diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp index 59df0cd6833db..e74afe5bc8fbb 100644 --- a/mlir/lib/IR/AffineExpr.cpp +++ b/mlir/lib/IR/AffineExpr.cpp @@ -239,6 +239,42 @@ bool AffineExpr::isPureAffine() const { llvm_unreachable("Unknown AffineExpr"); } +static bool isNonNegativeConstant(AffineExpr expr) { + auto constant = dyn_cast(expr); + return constant && constant.getValue() >= 0; +} + +bool AffineExpr::isMonotonicallyIncreasing() const { + switch (getKind()) { + case AffineExprKind::SymbolId: + case AffineExprKind::DimId: + case AffineExprKind::Constant: + return true; + case AffineExprKind::Add: { + auto op = llvm::cast(*this); + return op.getLHS().isMonotonicallyIncreasing() && + op.getRHS().isMonotonicallyIncreasing(); + } + case AffineExprKind::Mul: { + // One operand must be a non-negative constant. 
+ auto op = llvm::cast(*this); + return op.getLHS().isMonotonicallyIncreasing() && + op.getRHS().isMonotonicallyIncreasing() && + (isNonNegativeConstant(op.getLHS()) || + isNonNegativeConstant(op.getRHS())); + } + case AffineExprKind::FloorDiv: + case AffineExprKind::CeilDiv: { + auto op = llvm::cast(*this); + return op.getLHS().isMonotonicallyIncreasing() && + isNonNegativeConstant(op.getRHS()); + } + case AffineExprKind::Mod: + return false; + } + llvm_unreachable("Unknown AffineExpr"); +} + // Returns the greatest known integral divisor of this affine expression. int64_t AffineExpr::getLargestKnownDivisor() const { AffineBinaryOpExpr binExpr(nullptr); diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 4c9b9680ad841..8bcef9bcc9da5 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -2843,7 +2843,8 @@ void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef attrs, SmallString<16> separator = StringRef(", "); if (printerFlags.getNewlineAfterAttrLimit() && - attrs.size() > *printerFlags.getNewlineAfterAttrLimit()) { + std::distance(filteredAttrs.begin(), filteredAttrs.end()) > + *printerFlags.getNewlineAfterAttrLimit()) { // Increase indent to match the visually match the "{ " below. 
// currentIndent += 2; diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp index 0250ecb0f7f28..aacb049f32b09 100644 --- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp +++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp @@ -625,7 +625,8 @@ class Parser { struct { ast::UserRewriteDecl *addEntryToDictionaryAttr_Rewrite; ast::UserConstraintDecl *addEntryToDictionaryAttr_Constraint; - ast::UserRewriteDecl *addElemToArrayAttr; + ast::UserRewriteDecl *addElemToArrayAttrRewrite; + ast::UserConstraintDecl *addElemToArrayAttrConstraint; ast::UserRewriteDecl *mulRewrite; ast::UserRewriteDecl *divRewrite; ast::UserRewriteDecl *modRewrite; @@ -691,9 +692,13 @@ void Parser::declareBuiltins() { "__builtin_addEntryToDictionaryAttr_constraint", {"attr", "attrName", "attrEntry"}, /*returnsAttr=*/true); - builtins.addElemToArrayAttr = declareBuiltin( - "__builtin_addElemToArrayAttr", {"attr", "element"}, + builtins.addElemToArrayAttrRewrite = declareBuiltin( + "__builtin_addElemToArrayAttrRewriter", {"attr", "element"}, /*returnsAttr=*/true); + builtins.addElemToArrayAttrConstraint = + declareBuiltin( + "__builtin_addElemToArrayAttrConstraint", {"attr", "element"}, + /*returnsAttr=*/true); builtins.mulRewrite = declareBuiltin( "__builtin_mulRewrite", {"lhs", "rhs"}, true); builtins.divRewrite = declareBuiltin( @@ -2323,27 +2328,35 @@ FailureOr Parser::parseArrayAttrExpr() { consumeToken(Token::l_square); + ast::Decl *builtinFunction = builtins.addElemToArrayAttrRewrite; if (parserContext != ParserContext::Rewrite) - return emitError( - "Parsing of array attributes as constraint not supported!"); + builtinFunction = builtins.addElemToArrayAttrConstraint; - FailureOr arrayAttr = ast::AttributeExpr::create(ctx, curToken.getLoc(), "[]"); + FailureOr arrayAttr = + ast::AttributeExpr::create(ctx, curToken.getLoc(), "[]"); if (failed(arrayAttr)) return failure(); + // No values inside the array + if (consumeIf(Token::r_square)) { + return arrayAttr; + } + do { 
FailureOr attr = parseExpr(); if (failed(attr)) return failure(); SmallVector arrayAttrArgs{*arrayAttr, *attr}; - auto elemToArrayCall = createBuiltinCall( - curToken.getLoc(), builtins.addElemToArrayAttr, arrayAttrArgs); + + auto elemToArrayCall = + createBuiltinCall(curToken.getLoc(), builtinFunction, arrayAttrArgs); if (failed(elemToArrayCall)) return failure(); // Uses the new array for the next element. arrayAttr = elemToArrayCall; + } while (consumeIf(Token::comma)); if (failed( @@ -2415,7 +2428,8 @@ FailureOr Parser::parseDictAttrExpr() { consumeToken(Token::l_brace); SMRange loc = curToken.getLoc(); - FailureOr dictAttrCall = ast::AttributeExpr::create(ctx, loc, "{}"); + FailureOr dictAttrCall = + ast::AttributeExpr::create(ctx, loc, "{}"); if (failed(dictAttrCall)) return failure(); diff --git a/mlir/test/Analysis/test-affine-expr-bounds.mlir b/mlir/test/Analysis/test-affine-expr-bounds.mlir new file mode 100644 index 0000000000000..03115760a29d0 --- /dev/null +++ b/mlir/test/Analysis/test-affine-expr-bounds.mlir @@ -0,0 +1,217 @@ +// RUN: mlir-opt -test-affine-expr-bounds --mlir-print-local-scope --allow-unregistered-dialect --verify-diagnostics %s | FileCheck %s + +func.func @test_compute_affine_expr_bounds() { + // Add + + // CHECK: "test.add"() + // CHECK-SAME: expr_lb = 1 + // CHECK-SAME: expr_ub = 3 + "test.add"() {affine_map = affine_map<(d0) -> (d0 + 1)>, lbs = [0], ubs = [2]} : () -> () + + // CHECK: "test.sub_const"() + // CHECK-SAME: expr_lb = -1 + // CHECK-SAME: expr_ub = 1 + "test.sub_const"() {affine_map = affine_map<(d0) -> (d0 - 1)>, lbs = [0], ubs = [2]} : () -> () + + // CHECK: "test.sub_dim"() + // CHECK-SAME: expr_lb = -1 + // CHECK-SAME: expr_ub = 1 + "test.sub_dim"() {affine_map = affine_map<(d0) -> (1 - d0)>, lbs = [0], ubs = [2]} : () -> () + + // Mul + + // CHECK: "test.mul"() + // CHECK-SAME: expr_lb = 10 + // CHECK-SAME: expr_ub = 15 + "test.mul"() {affine_map = affine_map<(d0) -> (5 * d0)>, lbs = [2], ubs = [3]} : () -> () + + 
// CHECK: "test.mul_neg"() + // CHECK-SAME: expr_lb = -15 + // CHECK-SAME: expr_ub = -10 + "test.mul_neg"() {affine_map = affine_map<(d0) -> (-5 * d0)>, lbs = [2], ubs = [3]} : () -> () + + // Mod + + // CHECK: "test.mod_basic"() + // CHECK-SAME: expr_lb = 0 + // CHECK-SAME: expr_ub = 2 + "test.mod_basic"() {affine_map = affine_map<(d0) -> (d0 mod 5)>, lbs = [0], ubs = [2]} : () -> () + + // CHECK: "test.mod_wrap_around_by_range"() + // CHECK-SAME: expr_lb = 0 + // CHECK-SAME: expr_ub = 4 + "test.mod_wrap_around_by_range"() {affine_map = affine_map<(d0) -> (d0 mod 5)>, lbs = [0], ubs = [7]} : () -> () + + // CHECK: "test.mod_wrap_around_by_sum"() + // CHECK-SAME: expr_lb = 0 + // CHECK-SAME: expr_ub = 4 + "test.mod_wrap_around_by_sum"() {affine_map = affine_map<(d0) -> ((d0 + 3) mod 5)>, lbs = [0], ubs = [3]} : () -> () + + // CHECK: "test.mod_not_wrapping_around"() + // CHECK-SAME: expr_lb = 1 + // CHECK-SAME: expr_ub = 3 + "test.mod_not_wrapping_around"() {affine_map = affine_map<(d0) -> (((d0 + 12) mod 11) mod 5)>, lbs = [0], ubs = [2]} : () -> () + + // CHECK: "test.mod_neg"() + // CHECK-SAME: expr_lb = 1 + // CHECK-SAME: expr_ub = 3 + "test.mod_neg"() {affine_map = affine_map<(d0) -> (d0 mod 5)>, lbs = [-4], ubs = [-2]} : () -> () + + // CHECK: "test.mod_wrapping_by_zero"() + // CHECK-SAME: expr_lb = 0 + // CHECK-SAME: expr_ub = 4 + "test.mod_wrapping_by_zero"() {affine_map = affine_map<(d0) -> (d0 mod 5)>, lbs = [-2], ubs = [1]} : () -> () + + // FloorDiv + + // CHECK: "test.floordiv_basic"() + // CHECK-SAME: expr_lb = 0 + // CHECK-SAME: expr_ub = 1 + "test.floordiv_basic"() {affine_map = affine_map<(d0) -> (d0 floordiv 16)>, lbs = [0], ubs = [31]} : () -> () + + // CHECK: "test.floordiv_not_stepping"() + // CHECK-SAME: expr_lb = 1 + // CHECK-SAME: expr_ub = 1 + "test.floordiv_not_stepping"() {affine_map = affine_map<(d0) -> (d0 floordiv 16)>, lbs = [16], ubs = [31]} : () -> () + + // CHECK: "test.floordiv_stepping_by_sum"() + // CHECK-SAME: expr_lb = 1 + // 
CHECK-SAME: expr_ub = 2 + "test.floordiv_stepping_by_sum"() {affine_map = affine_map<(d0) -> ((d0 + 1) floordiv 16)>, lbs = [16], ubs = [31]} : () -> () + + // CHECK: "test.floordiv_neg_factor"() + // CHECK-SAME: expr_lb = -1 + // CHECK-SAME: expr_ub = 0 + "test.floordiv_neg_factor"() {affine_map = affine_map<(d0) -> (d0 floordiv -8)>, lbs = [0], ubs = [8]} : () -> () + + // CHECK: "test.floordiv_neg_factor_not_stepping"() + // CHECK-SAME: expr_lb = -1 + // CHECK-SAME: expr_ub = -1 + "test.floordiv_neg_factor_not_stepping"() {affine_map = affine_map<(d0) -> (d0 floordiv -8)>, lbs = [1], ubs = [8]} : () -> () + + // CHECK: "test.floordiv_neg_range"() + // CHECK-SAME: expr_lb = -1 + // CHECK-SAME: expr_ub = -1 + "test.floordiv_neg_range"() {affine_map = affine_map<(d0) -> (d0 floordiv 8)>, lbs = [-8], ubs = [-1]} : () -> () + + // CeilDiv + + // CHECK: "test.ceildiv_basic"() + // CHECK-SAME: expr_lb = 0 + // CHECK-SAME: expr_ub = 1 + "test.ceildiv_basic"() {affine_map = affine_map<(d0) -> (d0 ceildiv 16)>, lbs = [0], ubs = [16]} : () -> () + + // CHECK: "test.ceildiv_not_stepping"() + // CHECK-SAME: expr_lb = 1 + // CHECK-SAME: expr_ub = 1 + "test.ceildiv_not_stepping"() {affine_map = affine_map<(d0) -> (d0 ceildiv 16)>, lbs = [1], ubs = [16]} : () -> () + + // CHECK: "test.ceildiv_stepping_by_sum"() + // CHECK-SAME: expr_lb = 1 + // CHECK-SAME: expr_ub = 2 + "test.ceildiv_stepping_by_sum"() {affine_map = affine_map<(d0) -> ((d0 + 1) ceildiv 16)>, lbs = [1], ubs = [16]} : () -> () + + // CHECK: "test.ceildiv_neg_factor"() + // CHECK-SAME: expr_lb = -1 + // CHECK-SAME: expr_ub = 0 + "test.ceildiv_neg_factor"() {affine_map = affine_map<(d0) -> (d0 ceildiv -16)>, lbs = [1], ubs = [16]} : () -> () + + // CHECK: "test.ceildiv_neg_factor_not_stepping"() + // CHECK-SAME: expr_lb = 0 + // CHECK-SAME: expr_ub = 0 + "test.ceildiv_neg_factor_not_stepping"() {affine_map = affine_map<(d0) -> (d0 ceildiv -16)>, lbs = [0], ubs = [15]} : () -> () + + // CHECK: 
"test.ceildiv_neg_range"() + // CHECK-SAME: expr_lb = -1 + // CHECK-SAME: expr_ub = 0 + "test.ceildiv_neg_range"() {affine_map = affine_map<(d0) -> (d0 ceildiv 16)>, lbs = [-16], ubs = [-1]} : () -> () + + return +} + +// ----- + +func.func @test_bounds_unsigned() { + // CHECK: "test.unsigned"() + // CHECK-SAME: expr_lb = 0 : ui8 + // CHECK-SAME: expr_ub = 255 : ui8 + "test.unsigned"() {affine_map = affine_map<(d0) -> (d0)>, lbs = [0 : ui8], ubs = [255 : ui8]} : () -> () + + // CHECK: "test.unsigned_wrapping"() + // CHECK-SAME: expr_lb = 0 : ui8 + // CHECK-SAME: expr_ub = 255 : ui8 + "test.unsigned_wrapping"() {affine_map = affine_map<(d0) -> (d0 + 2)>, lbs = [253 : ui8], ubs = [255 : ui8]} : () -> () + + // CHECK: "test.unsigned_wrap_full"() + // CHECK-SAME: expr_lb = 0 : ui8 + // CHECK-SAME: expr_ub = 4 : ui8 + "test.unsigned_wrap_full"() {affine_map = affine_map<(d0) -> (d0 + 5)>, lbs = [251 : ui8], ubs = [255 : ui8]} : () -> () + + return +} + +// ----- + +func.func @test_unsigned_floordiv() { + // Result should be lb = 1, ub = 1, but we're missing an unsigned floordiv computation. 
+ // expected-error @+1 {{Failed to compute bounds}} + "test.unsigned_floordiv"() {affine_map = affine_map<(d0) -> (d0 floordiv 128)>, lbs = [129 : ui8], ubs = [129 : ui8]} : () -> () + +} + +// ----- + +func.func @test_bounds_signed() { + // CHECK: "test.signed"() + // CHECK-SAME: expr_lb = -1 : i8 + // CHECK-SAME: expr_ub = 0 : i8 + "test.signed"() {affine_map = affine_map<(d0) -> (d0 floordiv 16)>, lbs = [-1 : i8], ubs = [0 : i8]} : () -> () + + // CHECK: "test.signed_wrapping"() + // CHECK-SAME: expr_lb = -128 : i8 + // CHECK-SAME: expr_ub = 127 : i8 + "test.signed_wrapping"() {affine_map = affine_map<(d0) -> (d0 + 3)>, lbs = [124 : i8], ubs = [127 : i8]} : () -> () + + // CHECK: "test.signed_wrap_full"() + // CHECK-SAME: expr_lb = -128 : i8 + // CHECK-SAME: expr_ub = -125 : i8 + "test.signed_wrap_full"() {affine_map = affine_map<(d0) -> (d0 + 4)>, lbs = [124 : i8], ubs = [127 : i8]} : () -> () + + return +} + +// ----- + +func.func @test_dynamic_lb_basic() { + // expected-error @+1 {{Failed to compute bounds}} + "test.dynamic_lb_basic"() {affine_map = affine_map<(d0) -> (d0)>, lbs = ["?"], ubs = [1]} : () -> () + return +} + +// ----- + +func.func @test_dynamic_ub_basic() { + // expected-error @+1 {{Failed to compute bounds}} + "test.dynamic_ub_basic"() {affine_map = affine_map<(d0) -> (d0)>, lbs = [0], ubs = ["?"]} : () -> () + return +} + +// ----- + +func.func @test_dynamic_lb_unused() { + // CHECK: "test.dynamic_lb_unused"() + // CHECK-SAME: expr_lb = 14 + // CHECK-SAME: expr_ub = 16 + "test.dynamic_lb_unused"() {affine_map = affine_map<(d0, d1) -> (d1 + 2)>, lbs = ["?", 12], ubs = [1, 14]} : () -> () + return +} + +// ----- + +func.func @test_dynamic_ub_unused() { + // CHECK: "test.dynamic_ub_unused"() + // CHECK-SAME: expr_lb = 14 + // CHECK-SAME: expr_ub = 16 + "test.dynamic_ub_unused"() {affine_map = affine_map<(d0, d1) -> (d1 + 2)>, lbs = [0, 12], ubs = ["?", 14]} : () -> () + return +} diff --git 
a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir index aeb4c7233d1ff..42fadbc3a6a90 100644 --- a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir +++ b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir @@ -488,6 +488,27 @@ module @predicate_ordering { } } +// ----- + +// CHECK-LABEL: module @predicate_ordering_attr +module @predicate_ordering_attr { + // Check that the result is checked for null first, before applying the + // constraint. + + // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation) + // CHECK: %[[RESULT:.*]] = pdl_interp.get_attribute "attr" of %[[ROOT]] + // CHECK-NEXT: pdl_interp.is_not_null %[[RESULT]] + // CHECK: pdl_interp.apply_constraint "constraint" + + + pdl.pattern : benefit(1) { + %attr = attribute + pdl.apply_native_constraint "constraint"(%attr: !pdl.attribute) + pdl.apply_native_constraint "constraint"(%attr: !pdl.attribute) + %root = operation "foo.op" {"attr" = %attr} + rewrite %root with "rewriter" + } +} // ----- diff --git a/mlir/test/Dialect/Linalg/tile-tensors.mlir b/mlir/test/Dialect/Linalg/tile-tensors.mlir index 883eb732b2aa6..34125727201eb 100644 --- a/mlir/test/Dialect/Linalg/tile-tensors.mlir +++ b/mlir/test/Dialect/Linalg/tile-tensors.mlir @@ -177,9 +177,14 @@ func.func @non_monotonic_affine_expr(%arg0 : tensor<7xf32>) -> tensor<7xf32> { %0 = tensor.dim %arg0, %c0 : tensor<7xf32> %empty = tensor.empty() : tensor<7xf32> - // CHECK: %[[OUT:.*]] = tensor.empty() : tensor<7xf32> - // CHECK: scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC0:.*]] = %[[OUT]]) -> (tensor<7xf32>) { - // CHECK: tensor.extract_slice %[[TC0]][0] [7] [1] : tensor<7xf32> to tensor<7xf32> + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index + // CHECK-DAG: %[[OUT:.*]] = tensor.empty() : tensor<7xf32> + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index + // CHECK-DAG: %[[C7:.*]] = arith.constant 7 : index + // CHECK-DAG: 
%[[C7_1:.*]] = arith.constant 7 : index + // CHECK: scf.for %[[IV0:.+]] = %[[C0]] to %[[C7]] step %[[C7_1]] iter_args(%[[TC0:.*]] = %[[OUT]]) -> (tensor<7xf32>) { + // CHECK: tensor.extract_slice %[[ARG0]][0] [7] [1] : tensor<7xf32> to tensor<7xf32> + // CHECK: tensor.extract_slice %[[TC0]][%[[IV0]]] [7] [1] : tensor<7xf32> to tensor<7xf32> %generic = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0 mod 4)>, affine_map<(d0) -> (d0)>], @@ -199,3 +204,44 @@ module attributes {transform.with_named_sequence} { transform.yield } } + +// ----- + +#identity = affine_map<(d0, d1) -> (d0, d1)> +#identity1 = affine_map<(d0, d1) -> (d0 mod 3, d1)> + +// CHECK-LABEL: func @tile_monotonic_outer_dim +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<4x10xf32> +func.func @tile_monotonic_outer_dim(%in: tensor<4x10xf32>) -> tensor<4x10xf32> { + %empty = tensor.empty() : tensor<4x10xf32> + %1 = linalg.generic {indexing_maps = [#identity, #identity1], iterator_types = ["parallel", "parallel"]} + ins(%in : tensor<4x10xf32>) outs(%empty : tensor<4x10xf32>) { + ^bb1(%a: f32, %b: f32): + linalg.yield %a : f32 + } -> tensor<4x10xf32> + + // CHECK: %[[C4:.+]] = arith.constant 4 : index + // CHECK: %[[C4_1:.+]] = arith.constant 4 : index + // CHECK: %[[C5:.+]] = arith.constant 5 : index + // CHECK: scf.for %[[IV0:.+]] = %{{.+}} to %[[C4]] step %[[C4_1]] iter_args(%[[ARG1:.+]] = %[[OUT:.+]]) -> (tensor<4x10xf32>) { + // CHECK: scf.for %[[IV1:.+]] = %{{.+}} to %{{.+}} step %[[C5]] iter_args(%[[ARG2:.+]] = %[[ARG1]]) -> (tensor<4x10xf32>) { + // CHECK: %[[INSLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], %[[IV1]]] [4, 5] [1, 1] : tensor<4x10xf32> to tensor<4x5xf32> + // CHECK: %[[OUTSLICE:.+]] = tensor.extract_slice %[[ARG2]][0, %[[IV1]]] [4, 5] [1, 1] : tensor<4x10xf32> to tensor<4x5xf32> + // CHECK: %[[RES:.+]] = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} ins(%[[INSLICE]] : tensor<4x5xf32>) outs(%[[OUTSLICE]] : tensor<4x5xf32>) { + // 
CHECK: ^bb0(%in: f32, %out: f32): + // CHECK: linalg.yield %in : f32 + // CHECK: } -> tensor<4x5xf32> + // CHECK: %[[INSERT_SLICE:.+]] = tensor.insert_slice %[[RES]] into %[[ARG2]][0, %[[IV1]]] [4, 5] [1, 1] : tensor<4x5xf32> into tensor<4x10xf32> + // CHECK: scf.yield %[[INSERT_SLICE]] : tensor<4x10xf32> + // CHECK: } + + return %1 : tensor<4x10xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1, %loops:2 = transform.structured.tile_using_for %0 tile_sizes [4, 5] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + transform.yield + } +} diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index 1fe85b6411a7a..826811b0f2344 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -690,6 +690,81 @@ func.func @slice_nofold(%arg0: tensor) -> tensor { // ----- +// CHECK-LABEL: @slice_fuse +func.func @slice_fuse(%arg0: tensor<3x4xf32>) -> tensor<1x2xf32> { +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x4xf32>) -> tensor<1x2xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.slice [[PARAM_0_]] {size = array, start = array} : (tensor<3x4xf32>) -> tensor<1x2xf32> +// CHECK: return [[VAR_0_]] : tensor<1x2xf32> + %0 = tosa.slice %arg0 { size = array, start = array}: (tensor<3x4xf32>) -> tensor<2x3xf32> + %1 = tosa.slice %0 { size = array, start = array}: (tensor<2x3xf32>) -> tensor<1x2xf32> + return %1 : tensor<1x2xf32> +} + +// ----- + +// CHECK-LABEL: @slice_fuse_different_step +func.func @slice_fuse_different_step(%arg0: tensor<3x4xf32>) -> tensor<1x1xf32> { +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x4xf32>) -> tensor<1x1xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.slice [[PARAM_0_]] {size = array, start = array} : (tensor<3x4xf32>) -> 
tensor<1x1xf32> +// CHECK: return [[VAR_0_]] : tensor<1x1xf32> + %0 = tosa.slice %arg0 { size = array, start = array}: (tensor<3x4xf32>) -> tensor<1x3xf32> + %1 = tosa.slice %0 { size = array, start = array}: (tensor<1x3xf32>) -> tensor<1x1xf32> + return %1 : tensor<1x1xf32> +} + +// ----- + +// CHECK-LABEL: @slice_fuse_different_start +func.func @slice_fuse_different_start(%arg0: tensor<3x4xf32>) -> tensor<1x1xf32> { +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x4xf32>) -> tensor<1x1xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.slice [[PARAM_0_]] {size = array, start = array} : (tensor<3x4xf32>) -> tensor<1x1xf32> +// CHECK: return [[VAR_0_]] : tensor<1x1xf32> + %0 = tosa.slice %arg0 { size = array, start = array}: (tensor<3x4xf32>) -> tensor<1x3xf32> + %1 = tosa.slice %0 { size = array, start = array}: (tensor<1x3xf32>) -> tensor<1x1xf32> + return %1 : tensor<1x1xf32> +} + +// ----- + +// CHECK-LABEL: @slice_fuse_different_start_2 +func.func @slice_fuse_different_start_2(%arg0: tensor<10x10xf32>) -> tensor<1x1xf32> { +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x10xf32>) -> tensor<1x1xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.slice [[PARAM_0_]] {size = array, start = array} : (tensor<10x10xf32>) -> tensor<1x1xf32> +// CHECK: return [[VAR_0_]] : tensor<1x1xf32> + %0 = tosa.slice %arg0 { size = array, start = array}: (tensor<10x10xf32>) -> tensor<5x5xf32> + %1 = tosa.slice %0 { size = array, start = array}: (tensor<5x5xf32>) -> tensor<3x3xf32> + %2 = tosa.slice %1 { size = array, start = array}: (tensor<3x3xf32>) -> tensor<1x1xf32> + return %2 : tensor<1x1xf32> +} + +// ----- + +// CHECK-LABEL: @slice_fuse_different_start_3 +func.func @slice_fuse_different_start_3(%arg0: tensor<10x10xf32>) -> tensor<1x1xf32> { +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x10xf32>) -> tensor<1x1xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.slice [[PARAM_0_]] {size = array, start = array} : (tensor<10x10xf32>) -> tensor<1x1xf32> +// CHECK: return [[VAR_0_]] : tensor<1x1xf32> + %0 = tosa.slice %arg0 
{ size = array, start = array}: (tensor<10x10xf32>) -> tensor<5x5xf32> + %1 = tosa.slice %0 { size = array, start = array}: (tensor<5x5xf32>) -> tensor<3x3xf32> + %2 = tosa.slice %1 { size = array, start = array}: (tensor<3x3xf32>) -> tensor<1x1xf32> + return %2 : tensor<1x1xf32> +} + +// ----- + +// CHECK-LABEL: func.func @slice_fuse_different_start_dynamic +func.func @slice_fuse_different_start_dynamic(%arg0: tensor<*xf32>) -> tensor<*xf32> { +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xf32>) -> tensor<*xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.slice [[PARAM_0_]] {size = array, start = array} : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: return [[VAR_0_]] : tensor<*xf32> + %0 = tosa.slice %arg0 { size = array, start = array}: (tensor<*xf32>) -> tensor<*xf32> + %1 = tosa.slice %0 { size = array, start = array}: (tensor<*xf32>) -> tensor<*xf32> + %2 = tosa.slice %1 { size = array, start = array}: (tensor<*xf32>) -> tensor<*xf32> + return %2 : tensor<*xf32> +} + +// ----- + // CHECK-LABEL: @tile_fold func.func @tile_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> { // CHECK: return %arg0 @@ -700,6 +775,38 @@ func.func @tile_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> { // ----- +// CHECK-LABEL: func.func @tile_fuse_consecutive +func.func @tile_fuse_consecutive(%arg0: tensor<3x4xf32>) -> tensor<6x16xf32> { +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x4xf32>) -> tensor<6x16xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.const_shape {value = dense<[2, 4]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: [[VAR_1_:%.+]] = tosa.tile [[PARAM_0_]], [[VAR_0_]] : (tensor<3x4xf32>, !tosa.shape<2>) -> tensor<6x16xf32> +// CHECK: return [[VAR_1_]] : tensor<6x16xf32> + %cst = tosa.const_shape { value = dense<[1, 2]> : tensor<2xindex> } : () -> !tosa.shape<2> + %cst_1 = tosa.const_shape { value = dense<[2, 2]> : tensor<2xindex> } : () -> !tosa.shape<2> + %0 = tosa.tile %arg0, %cst: (tensor<3x4xf32>, !tosa.shape<2>) -> tensor<3x8xf32> + %1 = tosa.tile %0, %cst_1: (tensor<3x8xf32>, 
!tosa.shape<2>) -> tensor<6x16xf32> + return %1 : tensor<6x16xf32> +} + +// ----- + +// CHECK-LABEL: func.func @tile_no_fold_consecutive_multi_use +func.func @tile_no_fold_consecutive_multi_use(%arg0: tensor<3x4xf32>) -> (tensor<3x8xf32>, tensor<6x16xf32>) { +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x4xf32>) -> (tensor<3x8xf32>, tensor<6x16xf32>) { +// CHECK-DAG: [[VAR_0_:%.+]] = tosa.const_shape {value = dense<[1, 2]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK-DAG: [[VAR_1_:%.+]] = tosa.const_shape {value = dense<2> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: [[VAR_2_:%.+]] = tosa.tile [[PARAM_0_]], [[VAR_0_]] : (tensor<3x4xf32>, !tosa.shape<2>) -> tensor<3x8xf32> +// CHECK: [[VAR_3_:%.+]] = tosa.tile [[VAR_2_]], [[VAR_1_]] : (tensor<3x8xf32>, !tosa.shape<2>) -> tensor<6x16xf32> +// CHECK: return [[VAR_2_]], [[VAR_3_]] : tensor<3x8xf32>, tensor<6x16xf32> + %cst = tosa.const_shape { value = dense<[1, 2]> : tensor<2xindex> } : () -> !tosa.shape<2> + %cst_1 = tosa.const_shape { value = dense<[2, 2]> : tensor<2xindex> } : () -> !tosa.shape<2> + %0 = tosa.tile %arg0, %cst : (tensor<3x4xf32>, !tosa.shape<2>) -> tensor<3x8xf32> + %1 = tosa.tile %0, %cst_1 : (tensor<3x8xf32>, !tosa.shape<2>) -> tensor<6x16xf32> + return %0, %1 : tensor<3x8xf32>, tensor<6x16xf32> +} + +// ----- + // CHECK-LABEL: @tile_nofold func.func @tile_nofold(%arg0: tensor<3x4xf32>) -> tensor<3x8xf32> { // CHECK: tosa.tile @@ -814,6 +921,140 @@ func.func @canonicalize_cross_concat_inputs(%arg0 : tensor<1x12x12xf32>, %arg1 : // ----- +// CHECK-LABEL: func.func @canonicalize_concat_slice_partial_concat_start_overlap +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x12x12x2xf32>, [[PARAM_1_:%.+]]: tensor<1x12x12x2xf32>, [[PARAM_2_:%.+]]: tensor<1x12x12x2xf32>) -> tensor<1x12x12x2xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]] {axis = 3 : i32} : (tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>) -> tensor<1x12x12x4xf32> +// CHECK: [[VAR_1_:%.+]] = tosa.slice [[VAR_0_]] 
{size = array, start = array} : (tensor<1x12x12x4xf32>) -> tensor<1x12x12x2xf32> +// CHECK: return [[VAR_1_]] : tensor<1x12x12x2xf32> +func.func @canonicalize_concat_slice_partial_concat_start_overlap(%arg0 : tensor<1x12x12x2xf32>, %arg1 : tensor<1x12x12x2xf32>, %arg2 : tensor<1x12x12x2xf32>) -> tensor<1x12x12x2xf32> { + %0 = tosa.concat %arg0, %arg1, %arg2 {axis = 3 : i32} : (tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>) -> tensor<1x12x12x6xf32> + %1 = tosa.slice %0 {size = array, start = array} : (tensor<1x12x12x6xf32>) -> tensor<1x12x12x2xf32> + return %1 : tensor<1x12x12x2xf32> +} + +// ----- + +// CHECK-LABEL: func.func @canonicalize_concat_slice_partial_concat_end_overlap +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x12x12x2xf32>, [[PARAM_1_:%.+]]: tensor<1x12x12x2xf32>, [[PARAM_2_:%.+]]: tensor<1x12x12x2xf32>) -> tensor<1x12x12x2xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.concat [[PARAM_1_]], [[PARAM_2_]] {axis = 3 : i32} : (tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>) -> tensor<1x12x12x4xf32> +// CHECK: [[VAR_1_:%.+]] = tosa.slice [[VAR_0_]] {size = array, start = array} : (tensor<1x12x12x4xf32>) -> tensor<1x12x12x2xf32> +// CHECK: return [[VAR_1_]] : tensor<1x12x12x2xf32> +func.func @canonicalize_concat_slice_partial_concat_end_overlap(%arg0 : tensor<1x12x12x2xf32>, %arg1 : tensor<1x12x12x2xf32>, %arg2 : tensor<1x12x12x2xf32>) -> tensor<1x12x12x2xf32> { + %0 = tosa.concat %arg0, %arg1, %arg2 {axis = 3 : i32} : (tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>) -> tensor<1x12x12x6xf32> + %1 = tosa.slice %0 {size = array, start = array} : (tensor<1x12x12x6xf32>) -> tensor<1x12x12x2xf32> + return %1 : tensor<1x12x12x2xf32> +} + +// ----- + +// CHECK-LABEL: func.func @canonicalize_concat_slice_partial_concat_all_overlap +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x12x12x2xf32>, [[PARAM_1_:%.+]]: tensor<1x12x12x2xf32>, [[PARAM_2_:%.+]]: tensor<1x12x12x2xf32>) -> tensor<1x12x12x4xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.concat 
[[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]] {axis = 3 : i32} : (tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>) -> tensor<1x12x12x6xf32> +// CHECK: [[VAR_1_:%.+]] = tosa.slice [[VAR_0_]] {size = array, start = array} : (tensor<1x12x12x6xf32>) -> tensor<1x12x12x4xf32> +// CHECK: return [[VAR_1_]] : tensor<1x12x12x4xf32> +func.func @canonicalize_concat_slice_partial_concat_all_overlap(%arg0 : tensor<1x12x12x2xf32>, %arg1 : tensor<1x12x12x2xf32>, %arg2 : tensor<1x12x12x2xf32>) -> tensor<1x12x12x4xf32> { + %0 = tosa.concat %arg0, %arg1, %arg2 {axis = 3 : i32} : (tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>) -> tensor<1x12x12x6xf32> + %1 = tosa.slice %0 {size = array, start = array} : (tensor<1x12x12x6xf32>) -> tensor<1x12x12x4xf32> + return %1 : tensor<1x12x12x4xf32> +} + +// ----- + +// CHECK-LABEL: func.func @canonicalize_concat_slice_partial_concat_multi_use +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x12x12x2xf32>, [[PARAM_1_:%.+]]: tensor<1x12x12x2xf32>, [[PARAM_2_:%.+]]: tensor<1x12x12x2xf32>) -> (tensor<1x12x12x6xf32>, tensor<1x12x12x2xf32>) { +// CHECK: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]] {axis = 3 : i32} : (tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>) -> tensor<1x12x12x6xf32> +// CHECK: [[VAR_1_:%.+]] = tosa.slice [[VAR_0_]] {size = array, start = array} : (tensor<1x12x12x6xf32>) -> tensor<1x12x12x2xf32> +// CHECK: return [[VAR_0_]], [[VAR_1_]] : tensor<1x12x12x6xf32>, tensor<1x12x12x2xf32> +func.func @canonicalize_concat_slice_partial_concat_multi_use(%arg0 : tensor<1x12x12x2xf32>, %arg1 : tensor<1x12x12x2xf32>, %arg2 : tensor<1x12x12x2xf32>) -> (tensor<1x12x12x6xf32>, tensor<1x12x12x2xf32>) { + %0 = tosa.concat %arg0, %arg1, %arg2 {axis = 3 : i32} : (tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>) -> tensor<1x12x12x6xf32> + %1 = tosa.slice %0 {size = array, start = array} : (tensor<1x12x12x6xf32>) -> tensor<1x12x12x2xf32> + return %0, %1 : 
tensor<1x12x12x6xf32>, tensor<1x12x12x2xf32> +} + +// ----- + +// CHECK-LABEL: func.func @canonicalize_concat_slice_zero_dim +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x12x12x2xf32>, [[PARAM_1_:%.+]]: tensor<1x12x12x2xf32>, [[PARAM_2_:%.+]]: tensor<1x12x12x2xf32>) -> tensor<1x12x12x0xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]] {axis = 3 : i32} : (tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>) -> tensor<1x12x12x6xf32> +// CHECK: [[VAR_1_:%.+]] = tosa.slice [[VAR_0_]] {size = array, start = array} : (tensor<1x12x12x6xf32>) -> tensor<1x12x12x0xf32> +// CHECK: return [[VAR_1_]] : tensor<1x12x12x0xf32> +// CHECK: } +func.func @canonicalize_concat_slice_zero_dim(%arg0 : tensor<1x12x12x2xf32>, %arg1 : tensor<1x12x12x2xf32>, %arg2 : tensor<1x12x12x2xf32>) -> tensor<1x12x12x0xf32> { + %0 = tosa.concat %arg0, %arg1, %arg2 {axis = 3 : i32} : (tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>, tensor<1x12x12x2xf32>) -> tensor<1x12x12x6xf32> + %1 = tosa.slice %0 {size = array, start = array} : (tensor<1x12x12x6xf32>) -> tensor<1x12x12x0xf32> + return %1 : tensor<1x12x12x0xf32> +} + +// ----- + +// CHECK-LABEL: func.func @canonicalize_tile_slice +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x12x12x10x10x10xf32>) -> tensor<1x120x12x10x16x5xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.const_shape {value = dense<[1, 10, 2, 2, 3, 1]> : tensor<6xindex>} : () -> !tosa.shape<6> +// CHECK: [[VAR_1_:%.+]] = tosa.tile [[PARAM_0_]], [[VAR_0_]] : (tensor<1x12x12x10x10x10xf32>, !tosa.shape<6>) -> tensor<1x120x24x20x30x10xf32> +// CHECK: [[VAR_2_:%.+]] = tosa.slice [[VAR_1_]] {size = array, start = array} : (tensor<1x120x24x20x30x10xf32>) -> tensor<1x120x12x10x16x5xf32> +// CHECK: return [[VAR_2_]] : tensor<1x120x12x10x16x5xf32> +func.func @canonicalize_tile_slice(%arg0 : tensor<1x12x12x10x10x10xf32>) -> tensor<1x120x12x10x16x5xf32> { + %cst = tosa.const_shape { value = dense<[10, 10, 10, 10, 10, 10]> : tensor<6xindex> } : () -> 
!tosa.shape<6> + %0 = tosa.tile %arg0, %cst : (tensor<1x12x12x10x10x10xf32>, !tosa.shape<6>) -> tensor<10x120x120x100x100x100xf32> + %1 = tosa.slice %0 {size = array, start = array} : (tensor<10x120x120x100x100x100xf32>) -> tensor<1x120x12x10x16x5xf32> + return %1 : tensor<1x120x12x10x16x5xf32> +} + +// ----- + +// CHECK-LABEL: func.func @canonicalize_tile_slice_fold +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x12x12x10x10x10xf32>) -> tensor<1x12x12x10x10x10xf32> { +// CHECK: return [[PARAM_0_]] : tensor<1x12x12x10x10x10xf32> +func.func @canonicalize_tile_slice_fold(%arg0 : tensor<1x12x12x10x10x10xf32>) -> tensor<1x12x12x10x10x10xf32> { + %cst = tosa.const_shape { value = dense<[10, 10, 10, 10, 10, 10]> : tensor<6xindex> } : () -> !tosa.shape<6> + %0 = tosa.tile %arg0, %cst : (tensor<1x12x12x10x10x10xf32>, !tosa.shape<6>) -> tensor<10x120x120x100x100x100xf32> + %1 = tosa.slice %0 {size = array, start = array} : (tensor<10x120x120x100x100x100xf32>) -> tensor<1x12x12x10x10x10xf32> + return %1 : tensor<1x12x12x10x10x10xf32> +} + +// ----- + +// CHECK-LABEL: func.func @canonicalize_self_concat_slice +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> { +// CHECK: return [[PARAM_0_]] : tensor<1x2x3x4xf32> +func.func @canonicalize_self_concat_slice(%arg0 : tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> { + %0 = tosa.concat %arg0, %arg0 {axis = 3 : i32} : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x8xf32> + %1 = tosa.slice %0 {size = array, start = array} : (tensor<1x2x3x8xf32>) -> tensor<1x2x3x4xf32> + return %1 : tensor<1x2x3x4xf32> +} + +// ----- + +// CHECK-LABEL: func.func @canonicalize_tile_slice_zero_dim +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x12x12x10x10xf32>) -> tensor<1x0x12x10x16xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.const_shape {value = dense<10> : tensor<5xindex>} : () -> !tosa.shape<5> +// CHECK: [[VAR_1_:%.+]] = tosa.tile [[PARAM_0_]], [[VAR_0_]] : (tensor<1x12x12x10x10xf32>, !tosa.shape<5>) -> 
tensor<10x120x120x100x100xf32> +// CHECK: [[VAR_2_:%.+]] = tosa.slice [[VAR_1_]] {size = array, start = array} : (tensor<10x120x120x100x100xf32>) -> tensor<1x0x12x10x16xf32> +// CHECK: return [[VAR_2_]] : tensor<1x0x12x10x16xf32> +func.func @canonicalize_tile_slice_zero_dim(%arg0 : tensor<1x12x12x10x10xf32>) -> tensor<1x0x12x10x16xf32> { + %cst = tosa.const_shape { value = dense<[10, 10, 10, 10, 10]> : tensor<5xindex> } : () -> !tosa.shape<5> + %0 = tosa.tile %arg0, %cst : (tensor<1x12x12x10x10xf32>, !tosa.shape<5>) -> tensor<10x120x120x100x100xf32> + %1 = tosa.slice %0 {size = array, start = array} : (tensor<10x120x120x100x100xf32>) -> tensor<1x0x12x10x16xf32> + return %1 : tensor<1x0x12x10x16xf32> +} + +// ----- + +// CHECK-LABEL: func.func @canonicalize_tile_slice_multi_output +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x12x12x10x10xf32>) -> (tensor<10x120x120x100x100xf32>, tensor<1x12x12x10x16xf32>) { +// CHECK: [[VAR_0_:%.+]] = tosa.const_shape {value = dense<10> : tensor<5xindex>} : () -> !tosa.shape<5> +// CHECK: [[VAR_1_:%.+]] = tosa.tile [[PARAM_0_]], [[VAR_0_]] : (tensor<1x12x12x10x10xf32>, !tosa.shape<5>) -> tensor<10x120x120x100x100xf32> +// CHECK: [[VAR_2_:%.+]] = tosa.slice [[VAR_1_]] {size = array, start = array} : (tensor<10x120x120x100x100xf32>) -> tensor<1x12x12x10x16xf32> +// CHECK: return [[VAR_1_]], [[VAR_2_]] : tensor<10x120x120x100x100xf32>, tensor<1x12x12x10x16xf32> +func.func @canonicalize_tile_slice_multi_output(%arg0 : tensor<1x12x12x10x10xf32>) -> (tensor<10x120x120x100x100xf32>, tensor<1x12x12x10x16xf32>) { + %cst = tosa.const_shape { value = dense<[10, 10, 10, 10, 10]> : tensor<5xindex> } : () -> !tosa.shape<5> + %0 = tosa.tile %arg0, %cst : (tensor<1x12x12x10x10xf32>, !tosa.shape<5>) -> tensor<10x120x120x100x100xf32> + %1 = tosa.slice %0 {size = array, start = array} : (tensor<10x120x120x100x100xf32>) -> tensor<1x12x12x10x16xf32> + return %0, %1 : tensor<10x120x120x100x100xf32>, tensor<1x12x12x10x16xf32> +} + +// ----- + // 
CHECK-LABEL: @canonicalize_optimize_sqrt_reciprocal func.func @canonicalize_optimize_sqrt_reciprocal(%arg0: tensor<1x5x1x1xf32>) -> tensor<1x5x1x1xf32> { // CHECK: %[[RSQRT:.*]] = tosa.rsqrt %arg{{.*}} : (tensor<1x5x1x1xf32>) -> tensor<1x5x1x1xf32> diff --git a/mlir/test/Dialect/Tosa/constant-cast-opt.mlir b/mlir/test/Dialect/Tosa/constant-cast-opt.mlir index 75339eacb67d5..74421a6ab8ba9 100644 --- a/mlir/test/Dialect/Tosa/constant-cast-opt.mlir +++ b/mlir/test/Dialect/Tosa/constant-cast-opt.mlir @@ -71,6 +71,20 @@ func.func @cast_fold_f32_to_i8() -> tensor<5xi8> { return %1 : tensor<5xi8> } +// CHECK-LABEL: @cast_fold_f32_to_ui8 +// COM: Do not fold casts from floats to uint +func.func @cast_fold_f32_to_ui8() -> tensor<5xui8> { + // CHECK: tosa.const + // CHECK-NOT: tensor<5xui8> + // CHECK: tosa.cast + %0 = "tosa.const"() {value = + dense<[12.0, 0.0, 5.0, 32770.11, -32770.11]> : + tensor<5xf32> + } : () -> tensor<5xf32> + %1 = "tosa.cast"(%0) : (tensor<5xf32>) -> tensor<5xui8> + return %1 : tensor<5xui8> +} + // CHECK-LABEL: @cast_fold_float_to_int_infinity_zero_nan func.func @cast_fold_float_to_int_infinity_zero_nan() -> tensor<5xi16> { // Check if infinity and zero are translated properly. 
Don't expect any @@ -116,6 +130,71 @@ func.func @cast_fold_i32_to_i8() -> tensor<5xi8> { return %1 : tensor<5xi8> } +// CHECK-LABEL: @cast_fold_i8_to_ui8 +func.func @cast_fold_i8_to_ui8() -> tensor<3xui8> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}4, 0, 251{{.*}}tensor<3xui8> + // CHECK-NOT: tosa.cast + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = + dense<[4, 0, -5]> : + tensor<3xi8> + } : () -> tensor<3xi8> + %1 = "tosa.cast"(%0) : (tensor<3xi8>) -> tensor<3xui8> + return %1 : tensor<3xui8> +} + +// CHECK-LABEL: @cast_fold_ui8_to_i8 +func.func @cast_fold_ui8_to_i8() -> tensor<3xi8> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}4, 0, -6{{.*}}tensor<3xi8> + // CHECK-NOT: tosa.cast + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = + dense<[4, 0, 250]> : + tensor<3xui8> + } : () -> tensor<3xui8> + %1 = "tosa.cast"(%0) : (tensor<3xui8>) -> tensor<3xi8> + return %1 : tensor<3xi8> +} + +// CHECK-LABEL: @cast_fold_ui8_to_i16 +func.func @cast_fold_ui8_to_i16() -> tensor<3xi16> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}4, 0, 250{{.*}}tensor<3xi16> + // CHECK-NOT: tosa.cast + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = + dense<[4, 0, 250]> : + tensor<3xui8> + } : () -> tensor<3xui8> + %1 = "tosa.cast"(%0) : (tensor<3xui8>) -> tensor<3xi16> + return %1 : tensor<3xi16> +} + +// CHECK-LABEL: @cast_fold_ui8_to_i1 +func.func @cast_fold_ui8_to_i1() -> tensor<3xi1> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}true, false, true{{.*}}tensor<3xi1> + // CHECK-NOT: tosa.cast + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = + dense<[4, 0, 250]> : + tensor<3xui8> + } : () -> tensor<3xui8> + %1 = "tosa.cast"(%0) : (tensor<3xui8>) -> tensor<3xi1> + return %1 : tensor<3xi1> +} + +// CHECK-LABEL: @cast_fold_ui8_to_ui1 +func.func @cast_fold_ui8_to_ui1() -> tensor<3xui1> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}true, false, true{{.*}}tensor<3xui1> + // CHECK-NOT: tosa.cast + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = + dense<[4, 
0, 250]> : + tensor<3xui8> + } : () -> tensor<3xui8> + %1 = "tosa.cast"(%0) : (tensor<3xui8>) -> tensor<3xui1> + return %1 : tensor<3xui1> +} + // CHECK-LABEL: @cast_fold_i16_to_i1 func.func @cast_fold_i16_to_i1() -> tensor<3xi1> { @@ -172,6 +251,19 @@ func.func @cast_fold_i32_to_f16() -> tensor<4xf16> { return %1 : tensor<4xf16> } +// CHECK-LABEL: @cast_fold_ui8_to_f32 +func.func @cast_fold_ui8_to_f32() -> tensor<4xf32> { + // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}0.000000e+00, 1.000000e+00, 4.000000e+00, 2.550000e+02{{.*}}tensor<4xf32> + // CHECK-NOT: tosa.cast + // CHECK: return [[RES]] + %0 = "tosa.const"() {value = + dense<[0, 1, 4, 255]> : + tensor<4xui8> + } : () -> tensor<4xui8> + %1 = "tosa.cast"(%0) : (tensor<4xui8>) -> tensor<4xf32> + return %1 : tensor<4xf32> +} + // ----- // Casts from float to float diff --git a/mlir/test/Dialect/Tosa/fold_concats.mlir b/mlir/test/Dialect/Tosa/fold_concats.mlir index ec54f27346c8b..adc5875d943b0 100644 --- a/mlir/test/Dialect/Tosa/fold_concats.mlir +++ b/mlir/test/Dialect/Tosa/fold_concats.mlir @@ -5,10 +5,11 @@ func.func @single_concat(%arg0: tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> { return %0 : tensor<1x2x7x7xf32> } -// CHECK-LABEL: func.func @single_concat( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> { -// CHECK: %[[VAL_1:.*]] = tosa.concat %[[VAL_0]], %[[VAL_0]] {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> -// CHECK: return %[[VAL_1]] : tensor<1x2x7x7xf32> +// CHECK-LABEL: func.func @single_concat +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.const_shape {value = dense<[1, 2, 1, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: [[VAR_1_:%.+]] = tosa.tile [[PARAM_0_]], [[VAR_0_]] : (tensor<1x1x7x7xf32>, !tosa.shape<4>) -> tensor<1x2x7x7xf32> +// CHECK: return [[VAR_1_]] : tensor<1x2x7x7xf32> // CHECK: } // ----- @@ -19,11 +20,11 @@ func.func 
@concat_different_axis(%arg0: tensor<1x1x7x7xf32>) -> tensor<2x2x7x7xf return %1 : tensor<2x2x7x7xf32> } -// CHECK-LABEL: func.func @concat_different_axis( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x7x7xf32>) -> tensor<2x2x7x7xf32> { -// CHECK: %[[VAL_1:.*]] = tosa.concat %[[VAL_0]], %[[VAL_0]] {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> -// CHECK: %[[VAL_2:.*]] = tosa.concat %[[VAL_1]], %[[VAL_1]] {axis = 0 : i32} : (tensor<1x2x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<2x2x7x7xf32> -// CHECK: return %[[VAL_2]] : tensor<2x2x7x7xf32> +// CHECK-LABEL: func.func @concat_different_axis +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x7x7xf32>) -> tensor<2x2x7x7xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.const_shape {value = dense<[2, 2, 1, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: [[VAR_1_:%.+]] = tosa.tile [[PARAM_0_]], [[VAR_0_]] : (tensor<1x1x7x7xf32>, !tosa.shape<4>) -> tensor<2x2x7x7xf32> +// CHECK: return [[VAR_1_]] : tensor<2x2x7x7xf32> // CHECK: } // ----- @@ -84,10 +85,10 @@ func.func @partially_foldable(%arg0: tensor<1x1x8x8xf32>, %arg1: tensor<1x2x4x8x return %2 : tensor<1x4x8x8xf32> } -// CHECK-LABEL: func.func @partially_foldable( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x8x8xf32>, -// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2x4x8xf32>) -> tensor<1x4x8x8xf32> { -// CHECK: %[[VAL_2:.*]] = tosa.concat %[[VAL_1]], %[[VAL_1]] {axis = 2 : i32} : (tensor<1x2x4x8xf32>, tensor<1x2x4x8xf32>) -> tensor<1x2x8x8xf32> -// CHECK: %[[VAL_3:.*]] = tosa.concat %[[VAL_0]], %[[VAL_0]], %[[VAL_2]] {axis = 1 : i32} : (tensor<1x1x8x8xf32>, tensor<1x1x8x8xf32>, tensor<1x2x8x8xf32>) -> tensor<1x4x8x8xf32> -// CHECK: return %[[VAL_3]] : tensor<1x4x8x8xf32> -// CHECK: } +// CHECK-LABEL: func.func @partially_foldable +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x8x8xf32>, [[PARAM_1_:%.+]]: tensor<1x2x4x8xf32>) -> tensor<1x4x8x8xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.const_shape {value = dense<[1, 1, 2, 1]> : tensor<4xindex>} : () -> 
!tosa.shape<4> +// CHECK: [[VAR_1_:%.+]] = tosa.tile [[PARAM_1_]], [[VAR_0_]] : (tensor<1x2x4x8xf32>, !tosa.shape<4>) -> tensor<1x2x8x8xf32> +// CHECK: [[VAR_2_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_0_]], [[VAR_1_]] {axis = 1 : i32} : (tensor<1x1x8x8xf32>, tensor<1x1x8x8xf32>, tensor<1x2x8x8xf32>) -> tensor<1x4x8x8xf32> +// CHECK: return [[VAR_2_]] : tensor<1x4x8x8xf32> +// CHECK: } \ No newline at end of file diff --git a/mlir/test/IR/mlir-newline-after-attr.mlir b/mlir/test/IR/mlir-newline-after-attr.mlir index d35eac21a5152..047a9257d563c 100644 --- a/mlir/test/IR/mlir-newline-after-attr.mlir +++ b/mlir/test/IR/mlir-newline-after-attr.mlir @@ -29,3 +29,6 @@ // CHECK-NEXT: ], "test.op"() {foo.dense_attr = dense<1> : tensor<3xi32>, foo.second_attr = dense<2> : tensor<3xi32>, Operands = [{foo.vect_attr_1_start = dense<0> : vector<3xindex>, foo.vect_attr_1_end = dense<0> : vector<3xindex>, foo.vect_attr_1_count = dense<1> : vector<3xindex>, foo.vect_attr_2_start = dense<0> : vector<3xindex>, foo.vect_attr_2_end = dense<0> : vector<3xindex>, foo.vect_attr_2_count = dense<1> : vector<3xindex>}, {foo.vect_attr_1_start = dense<0> : vector<3xindex>, foo.vect_attr_1_end = dense<0> : vector<3xindex>, foo.vect_attr_1_count = dense<1> : vector<3xindex>, foo.vect_attr_2_start = dense<0> : vector<3xindex>, foo.vect_attr_2_end = dense<0> : vector<3xindex>, foo.vect_attr_2_count = dense<1> : vector<3xindex>}]} : () -> () +// const_shape skips over shape attr when printing. 
Check that we do not insert unnecessary newlines +// CHECK{LITERAL}: shape.const_shape {foo.second_attr = dense<2> : tensor<3xi32>, foo.third_attr = dense<2> : tensor<3xi32>}[1, 1, 1] : tensor<3xindex> +"shape.const_shape"() {shape = dense<1> : tensor<3xindex>, foo.second_attr = dense<2> : tensor<3xi32>, foo.third_attr = dense<2> : tensor<3xi32>} : () -> (tensor<3xindex>) \ No newline at end of file diff --git a/mlir/test/lib/IR/CMakeLists.txt b/mlir/test/lib/IR/CMakeLists.txt index 01297ad0a1148..9fe2ba0c610ef 100644 --- a/mlir/test/lib/IR/CMakeLists.txt +++ b/mlir/test/lib/IR/CMakeLists.txt @@ -1,5 +1,6 @@ # Exclude tests from libMLIR.so add_mlir_library(MLIRTestIR + TestAffineExpressionBounds.cpp TestAffineWalk.cpp TestBytecodeRoundtrip.cpp TestBuiltinAttributeInterfaces.cpp diff --git a/mlir/test/lib/IR/TestAffineExpressionBounds.cpp b/mlir/test/lib/IR/TestAffineExpressionBounds.cpp new file mode 100644 index 0000000000000..dc0b49d32130e --- /dev/null +++ b/mlir/test/lib/IR/TestAffineExpressionBounds.cpp @@ -0,0 +1,190 @@ +//===- TestAffineExpressionBounds.cpp - Test affine expression bounds --=====// +//----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/AffineExprBounds.h" + +#include "mlir/Dialect/DLTI/DLTI.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" + +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/FormatVariadic.h" + +#include "TestDialect.h" + +using namespace mlir; + +namespace { + +struct TestAffineExpressionBounds + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAffineExpressionBounds) + + StringRef getArgument() const final { return "test-affine-expr-bounds"; } + StringRef getDescription() const final { + return "Test simplify affine expression simplication"; + } + + FailureOr>> + getBound(Operation *op, StringRef boundType, bool *resultSigned, + uint64_t *resultWidth, bool optional = false) { + SmallVector> result; + + bool isSigned = false; + uint64_t width = 0; + + auto dict = op->getAttrDictionary(); + if (!dict) { + return op->emitError("No dictionary found"); + } + + auto bounds = dict.getNamed(boundType); + if (!bounds) { + if (!optional) { + return op->emitError(llvm::formatv("No {} attribute found", boundType)); + } + return failure(); + } + + auto boundsValue = cast(bounds->getValue()); + + for (auto v : boundsValue) { + if (auto value = dyn_cast(v)) { + if (width == 0) { + isSigned = (value.getType().isSignedInteger() || + value.getType().isSignlessInteger()); + width = value.getType().getIntOrFloatBitWidth(); + } else if (isSigned != (value.getType().isSignedInteger() || + value.getType().isSignlessInteger())) { + return op->emitError("Mixed signedness in bounds"); + } else if (width != value.getType().getIntOrFloatBitWidth()) { + return op->emitError("Mixed width in bounds"); + } + result.push_back(value.getValue()); + } else if (auto value = 
dyn_cast(v)) { + if (value.getValue() == "?") { + result.push_back(std::nullopt); + } else { + return op->emitError("Unknown string value found"); + } + } else { + return op->emitError("Non-integer or string value found in bounds"); + } + } + + *resultSigned = isSigned; + *resultWidth = width; + + return result; + } + + FailureOr getAffineExpr(Operation *op) { + auto dict = op->getAttrDictionary(); + if (!dict) { + return op->emitError("No dictionary found"); + } + auto affineMap = dict.getNamed("affine_map"); + if (!affineMap) { + return op->emitError("No affine_map attribute found"); + } + auto mapAttr = dyn_cast(affineMap->getValue()); + if (!mapAttr) { + return op->emitError("Invalid affine_map attribute found"); + } + + auto map = mapAttr.getAffineMap(); + if (map.getNumResults() != 1) { + return op->emitError("Invalid number of affine_map results"); + } + + return map.getResult(0); + } + + void runOnOperation() override { + func::FuncOp func = getOperation(); + IRRewriter rewriter(func.getContext()); + + func.walk([&](Operation *op) { + if (op->getDialect() != + op->getContext()->getLoadedDialect()) { + return; + } + + auto expr = getAffineExpr(op); + bool ubSigned, lbSigned; + uint64_t ubWidth, lbWidth; + auto ubs = getBound(op, "ubs", &ubSigned, &ubWidth); + auto lbs = getBound(op, "lbs", &lbSigned, &lbWidth); + + if (failed(expr) || failed(ubs) || failed(lbs)) { + return; + } + + if (ubs->size() != lbs->size()) { + op->emitError("Mismatched number of bounds"); + return; + } + if (ubWidth != lbWidth && + !((ubWidth == 0 && lbWidth > 0) || (ubWidth > 0 && lbWidth == 0))) { + op->emitError("Mismatched width in bounds"); + return; + } + bool signCheck = + !(ubWidth == 0 && lbWidth > 0) && !(ubWidth > 0 && lbWidth == 0); + if (signCheck && (ubSigned != lbSigned)) { + op->emitError("Mixed signedness in bounds"); + return; + } + + uint64_t width = (ubWidth == 0) ? 
lbWidth : ubWidth; + + AffineExprBoundsVisitor visitor(*lbs, *ubs, lbSigned, width, + &getContext()); + auto exprLB = visitor.getLowerBound(*expr); + auto exprUB = visitor.getUpperBound(*expr); + + if (!exprLB || !exprUB) { + op->emitError("Failed to compute bounds"); + return; + } + + auto namedAttrList = mlir::NamedAttrList{rewriter.getDictionaryAttr( + {rewriter.getNamedAttr( + "expr_lb", + IntegerAttr::get( + IntegerType::get( + &getContext(), width, + (lbSigned) ? IntegerType::SignednessSemantics::Signless + : IntegerType::SignednessSemantics::Unsigned), + *exprLB)), + rewriter.getNamedAttr( + "expr_ub", + IntegerAttr::get( + IntegerType::get( + &getContext(), width, + (ubSigned) ? IntegerType::SignednessSemantics::Signless + : IntegerType::SignednessSemantics::Unsigned), + *exprUB))})}; + op->setAttrs(namedAttrList); + }); + } +}; +} // namespace + +namespace mlir { +namespace test { +void registerTestAffineExpressionBounds() { + PassRegistration(); +} +} // namespace test +} // namespace mlir diff --git a/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll b/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll index 84cba9035123f..b876eecadfbce 100644 --- a/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll +++ b/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll @@ -218,7 +218,7 @@ Pattern RewriteMultipleEntriesDictionary { // CHECK: %[[VAL_4:.*]] = attribute = "firstAttr" // CHECK: %[[VAL_5:.*]] = attribute = "test1" // CHECK: %[[VAL_6:.*]] = apply_native_rewrite "__builtin_addEntryToDictionaryAttr_rewrite"(%[[VAL_3]], %[[VAL_4]], %[[VAL_5]] -// CHECK: %[[VAL_7:.*]] = apply_native_rewrite "__builtin_addElemToArrayAttr"(%[[VAL_2]], %[[VAL_6]] +// CHECK: %[[VAL_7:.*]] = apply_native_rewrite "__builtin_addElemToArrayAttrRewriter"(%[[VAL_2]], %[[VAL_6]] // CHECK: %[[VAL_8:.*]] = operation "test.success" {"some_array" = %[[VAL_7]]} // CHECK: replace %[[VAL_1]] with %[[VAL_8]] Pattern RewriteOneDictionaryArrayAttr { @@ -229,6 +229,43 @@ Pattern RewriteOneDictionaryArrayAttr { }; } +// ----- + +// 
CHECK-LABEL: pdl.pattern @ConstraintWithArrayAttr +// CHECK: %[[VAL_0:.*]] = attribute = "test1" +// CHECK: %[[VAL_1:.*]] = attribute = "test2" +// CHECK: %[[VAL_2:.*]] = attribute = [] +// CHECK: %[[VAL_3:.*]] = apply_native_constraint "__builtin_addElemToArrayAttrConstraint"(%[[VAL_2]], %[[VAL_0]] +// CHECK: %[[VAL_4:.*]] = apply_native_constraint "__builtin_addElemToArrayAttrConstraint"(%[[VAL_3]], %[[VAL_1]] +// CHECK: %[[VAL_5:.*]] = operation "test.op" +// CHECK: rewrite %[[VAL_5]] { +// CHECK: %[[VAL_6:.*]] = operation "test.success" {"some_array" = %[[VAL_4]]} +// CHECK: replace %[[VAL_5]] with %[[VAL_6]] + +Pattern ConstraintWithArrayAttr { + let attr1 = attr<"\"test1\"">; + let attr2 = attr<"\"test2\"">; + let array = [attr1, attr2]; + let root = op -> (); + rewrite root with { + let newRoot = op() { some_array = array} -> (); + replace root with newRoot; + }; +} + +// ----- + +// CHECK-LABEL: pdl.pattern @ConstraintNotMatchingArrayAttrInAttrType +// CHECK-NOT: apply_native_constraint "__builtin_addElemToArrayAttrConstraint" + + +Constraint I64Value(value: Value); +Pattern ConstraintNotMatchingArrayAttrInAttrType { + let root = op(arg: Value, arg2: Value, arg3: [Value, I64Value], arg); + replace root with arg; +} + + // ----- // CHECK-LABEL: pdl.pattern @RewriteMultiplyElementsArrayAttr @@ -240,8 +277,8 @@ Pattern RewriteOneDictionaryArrayAttr { // CHECK: %[[VAL_5:.*]] = attribute = "firstAttr" // CHECK: %[[VAL_6:.*]] = attribute = "test1" // CHECK: %[[VAL_7:.*]] = apply_native_rewrite "__builtin_addEntryToDictionaryAttr_rewrite"(%[[VAL_4]], %[[VAL_5]], %[[VAL_6]] -// CHECK: %[[VAL_8:.*]] = apply_native_rewrite "__builtin_addElemToArrayAttr"(%[[VAL_3]], %[[VAL_7]] -// CHECK: %[[VAL_9:.*]] = apply_native_rewrite "__builtin_addElemToArrayAttr"(%[[VAL_8]], %[[VAL_2]] +// CHECK: %[[VAL_8:.*]] = apply_native_rewrite "__builtin_addElemToArrayAttrRewriter"(%[[VAL_3]], %[[VAL_7]] +// CHECK: %[[VAL_9:.*]] = apply_native_rewrite 
"__builtin_addElemToArrayAttrRewriter"(%[[VAL_8]], %[[VAL_2]] // CHECK: %[[VAL_10:.*]] = operation "test.success" {"some_array" = %[[VAL_9]]} // CHECK: replace %[[VAL_1]] with %[[VAL_10]] Pattern RewriteMultiplyElementsArrayAttr { diff --git a/mlir/test/mlir-pdll/Parser/expr-failure.pdll b/mlir/test/mlir-pdll/Parser/expr-failure.pdll index 34cf54fb7c23d..9d1218f124009 100644 --- a/mlir/test/mlir-pdll/Parser/expr-failure.pdll +++ b/mlir/test/mlir-pdll/Parser/expr-failure.pdll @@ -134,6 +134,22 @@ Pattern { // ----- +Pattern ConstraintArrayAttrWithAttrAndValue { + let root = op(arg: Value) -> (); + let attr1 = attr<"\"test1\"">; + let array = [attr1, arg]; + // CHECK: unable to convert expression of type `Value` to the expected type of `Attr` + let root = op -> (); + rewrite root with { + let newRoot = op() { some_array = array} -> (); + replace root with newRoot; + }; +} + +// ----- + + + //===----------------------------------------------------------------------===// // Range Expr //===----------------------------------------------------------------------===// diff --git a/mlir/test/mlir-pdll/Parser/expr.pdll b/mlir/test/mlir-pdll/Parser/expr.pdll index fe3d8956b3dd7..8acdb7e9bba77 100644 --- a/mlir/test/mlir-pdll/Parser/expr.pdll +++ b/mlir/test/mlir-pdll/Parser/expr.pdll @@ -34,7 +34,7 @@ Pattern { // CHECK-LABEL: Module // CHECK: |-NamedAttributeDecl {{.*}} Name -// CHECK: `-UserRewriteDecl {{.*}} Name<__builtin_addElemToArrayAttr> ResultType +// CHECK: `-UserRewriteDecl {{.*}} Name<__builtin_addElemToArrayAttrRewriter> ResultType // CHECK: `Arguments` // CHECK: CallExpr {{.*}} Type // CHECK: AttributeExpr {{.*}} Value<"[]"> @@ -87,6 +87,77 @@ Constraint getPopulatedDict() -> Attr { return dictionary; } + + +// ----- + +// CHECK-LABEL: Module +// CHECK:LetStmt {{.*}} +//CHECK-NEXT:`-VariableDecl {{.*}} Name Type +//CHECK-NEXT: `-AttributeExpr {{.*}} Value<"[]"> +//CHECK-NEXT:ReturnStmt {{.*}} + +Constraint getEmtpyArray() -> Attr { + let array = []; + return 
array; +} + +// ----- + +// CHECK-LABEL: Module +// CHECK:LetStmt {{.*}} +//CHECK-NEXT:`-VariableDecl {{.*}} Name Type +//CHECK-NEXT: `-CallExpr {{.*}} Type +//CHECK-NEXT: `-DeclRefExpr {{.*}} Type +//CHECK-NEXT: `-UserConstraintDecl {{.*}} Name<__builtin_addElemToArrayAttrConstraint> ResultType +// CHECK: `Arguments` +//CHECK-NEXT: |-AttributeExpr {{.*}} Value<"[]"> +//CHECK-NEXT: `-AttributeExpr {{.*}} Value<""attr1""> +//CHECK-NEXT:ReturnStmt {{.*}} + +Constraint getPopulateArray() -> Attr { + let array = ["attr1"]; + return array; +} + + +// ----- + + +// CHECK-LABEL: Module +// CHECK:LetStmt {{.*}} +//CHECK-NEXT:`-VariableDecl {{.*}} Name Type +//CHECK-NEXT: `-CallExpr {{.*}} Type +//CHECK-NEXT: `-DeclRefExpr {{.*}} Type +//CHECK-NEXT: `-UserConstraintDecl {{.*}} Name<__builtin_addElemToArrayAttrConstraint> ResultType +// CHECK-DAG: `Arguments` +//CHECK-NEXT: |-CallExpr {{.*}} Type +//CHECK-NEXT: `-DeclRefExpr {{.*}} Type +//CHECK-NEXT: | `-UserConstraintDecl {{.*}} Name<__builtin_addElemToArrayAttrConstraint> ResultType +// CHECK-DAG: `Arguments` +//CHECK-NEXT: |-AttributeExpr {{.*}} Value<"[]"> +//CHECK-NEXT: `-CallExpr {{.*}} Type +//CHECK-NEXT: `-DeclRefExpr {{.*}} Type +//CHECK-NEXT: `-UserConstraintDecl {{.*}} Name ResultType +// CHECK: `-CallExpr {{.*}} Type +//CHECK-NEXT: `-DeclRefExpr {{.*}} Type +//CHECK-NEXT: `-UserConstraintDecl {{.*}} Name ResultType +// CHECK-DAG: -ReturnStmt {{.*}} + +Constraint getA() -> Attr { + return "A"; +} + +Constraint getB() -> Attr { + return "B"; +} + +Constraint getPopulateArrayFromOtherConstraints() -> Attr { + let array = [getA(), getB()]; + return array; +} + + // ----- //===----------------------------------------------------------------------===// diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp index 960f7037a1b61..4e369e489b19f 100644 --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -75,6 +75,7 @@ void registerInliner(); void 
registerMemRefBoundCheck(); void registerPatternsTestPass(); void registerSimpleParametricTilingPass(); +void registerTestAffineExpressionBounds(); void registerTestAffineLoopParametricTilingPass(); void registerTestAliasAnalysisPass(); void registerTestArithEmulateWideIntPass(); @@ -214,6 +215,7 @@ void registerTestPasses() { mlir::test::registerMemRefBoundCheck(); mlir::test::registerPatternsTestPass(); mlir::test::registerSimpleParametricTilingPass(); + mlir::test::registerTestAffineExpressionBounds(); mlir::test::registerTestAffineLoopParametricTilingPass(); mlir::test::registerTestAliasAnalysisPass(); mlir::test::registerTestArithEmulateWideIntPass(); diff --git a/mlir/unittests/Dialect/PDL/BuiltinTest.cpp b/mlir/unittests/Dialect/PDL/BuiltinTest.cpp index 21a620e3b6675..113fc1ff8640f 100644 --- a/mlir/unittests/Dialect/PDL/BuiltinTest.cpp +++ b/mlir/unittests/Dialect/PDL/BuiltinTest.cpp @@ -66,13 +66,17 @@ TEST_F(BuiltinTest, addEntryToDictionaryAttr) { } TEST_F(BuiltinTest, addElemToArrayAttr) { + TestPDLResultList results(1); + auto dict = rewriter.getDictionaryAttr( rewriter.getNamedAttr("key", rewriter.getStringAttr("value"))); rewriter.getArrayAttr({}); auto arrAttr = rewriter.getArrayAttr({}); + EXPECT_TRUE(succeeded( + builtin::addElemToArrayAttr(rewriter, results, {arrAttr, dict}))); mlir::Attribute updatedArrAttr = - builtin::addElemToArrayAttr(rewriter, arrAttr, dict); + results.getResults().front().cast(); auto dictInsideArrAttr = cast(*cast(updatedArrAttr).begin()); @@ -617,7 +621,7 @@ TEST_F(BuiltinTest, log2) { cast(result.cast()).getValue().convertToFloat(), 2.0); } - + auto threeF16 = rewriter.getF16FloatAttr(3.0); // check correctness @@ -626,7 +630,8 @@ TEST_F(BuiltinTest, log2) { EXPECT_TRUE(builtin::log2(rewriter, results, {threeF16}).succeeded()); PDLValue result = results.getResults()[0]; - float resultVal = cast(result.cast()).getValue().convertToFloat(); + float resultVal = + cast(result.cast()).getValue().convertToFloat(); 
EXPECT_TRUE(resultVal > 1.58 && resultVal < 1.59); } }