Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ void populateTosaDecomposeTransposeConv(MLIRContext *ctx,
RewritePatternSet &patterns);
void populateTosaDecomposeDepthwise(MLIRContext *ctx,
RewritePatternSet &patterns);
void populateTosaFoldConstantAddPatterns(MLIRContext *ctx,
RewritePatternSet &patterns);
void populateTosaFoldConstantCastPatterns(MLIRContext *ctx,
RewritePatternSet &patterns,
bool enableIntCastFolding);
Expand Down
17 changes: 16 additions & 1 deletion mlir/include/mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,12 @@ DenseElementsAttr applyElementWise(
/// Apply the given transformation function on the elements of the given
/// tensors. If the input tensors do not match \p targetType, broadcasting is
/// applied.
template <class ElementType, class ResultType>
DenseElementsAttr applyElementWise(
const DenseElementsAttr &first, const DenseElementsAttr &second,
TensorType targetType,
const std::function<APFloat(const APFloat &, const APFloat &)> &toApply);
const std::function<ResultType(const ElementType &, const ElementType &)>
&toApply);

/// Function that checks if \p toCheck is a dense TOSA constant float tensor.
LogicalResult notifyIfNotConstantFloatTosaTensor(TypedValue<TensorType> toCheck,
Expand Down Expand Up @@ -78,6 +80,19 @@ OffsetType getBroadcastedOffset(DimensionType desiredShape,
DimensionType toBeBroadcastedShape,
OffsetType offset);

/// Heuristic to decide when to replace a binary operation on constants with the
/// folded value.
/// Folding operations on constants can lead to an increased memory usage
/// whenever none of the inputs can be replaced but a new constant is inserted.
/// Hence, this will currently only suggest folding when the memory impact is
/// negligible.
/// Takes the \p binaryOp and the constant values of both operands,
/// \p valuesFirst and \p valuesSecond.
/// \returns Whether folding should be applied.
bool constantBinaryOpShouldBeFolded(TosaOp binaryOp,
DenseElementsAttr valuesFirst,
DenseElementsAttr valuesSecond);

/// Function to compute the reciprocal.
APFloat computeReciprocal(const APFloat &floatVal, FloatType floatTy);

Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRTosaTransforms
TosaDecomposeConv2D.cpp
TosaDecomposeDepthwise.cpp
TosaFoldCommon.cpp
TosaFoldConstantAdd.cpp
TosaFoldConstantCast.cpp
TosaFoldConstantPow.cpp
TosaFoldConstantReciprocal.cpp
Expand Down
46 changes: 42 additions & 4 deletions mlir/lib/Dialect/Tosa/Transforms/TosaFoldCommon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,14 @@ mlir::tosa::applyElementWise<APInt, APInt, IntegerType>(
const std::function<APInt(const APInt &, IntegerType)> &toApply,
IntegerType targetType);

template <class ElementType, class ResultType>
DenseElementsAttr mlir::tosa::applyElementWise(
const DenseElementsAttr &first, const DenseElementsAttr &second,
TensorType targetType,
const std::function<APFloat(const APFloat &, const APFloat &)> &toApply) {
const std::function<ResultType(const ElementType &, const ElementType &)>
&toApply) {
// Make sure to use the correct values in case broadcasting is required
SmallVector<APFloat> transformedValues;
SmallVector<ResultType> transformedValues;
// We already know the amount of values we will insert, reserve space for
// all of them to avoid dynamic resizing
auto targetSize = 1;
Expand All @@ -86,9 +88,9 @@ DenseElementsAttr mlir::tosa::applyElementWise(

// Apply the given function to each pair of values from the input tensors.
// Make sure to broadcast the offsets properly.
auto firstIt = first.getValues<APFloat>();
auto firstIt = first.getValues<ElementType>();
auto firstShape = first.getType().getShape();
auto secondIt = second.getValues<APFloat>();
auto secondIt = second.getValues<ElementType>();
auto secondShape = second.getType().getShape();
for (auto offset = 0; offset < targetSize; offset++) {
OffsetType offsetInTargetFirst =
Expand All @@ -105,6 +107,16 @@ DenseElementsAttr mlir::tosa::applyElementWise(
return newTensor;
}

// Explicit instantiations of the broadcasting binary applyElementWise
// overload for the element-type pairs used by the constant-folding patterns:
// float tensors (APFloat) and integer tensors (APInt).
template DenseElementsAttr mlir::tosa::applyElementWise<APFloat, APFloat>(
    const DenseElementsAttr &first, const DenseElementsAttr &second,
    TensorType targetType,
    const std::function<APFloat(const APFloat &, const APFloat &)> &toApply);

template DenseElementsAttr mlir::tosa::applyElementWise<APInt, APInt>(
    const DenseElementsAttr &first, const DenseElementsAttr &second,
    TensorType targetType,
    const std::function<APInt(const APInt &, const APInt &)> &toApply);

LogicalResult
mlir::tosa::notifyIfNotConstantFloatTosaTensor(TypedValue<TensorType> toCheck,
TosaOp location,
Expand Down Expand Up @@ -205,6 +217,32 @@ OffsetType mlir::tosa::getBroadcastedOffset(DimensionType desiredShape,
return indexToOffset(toBeBroadcastedShape, indexBroadcasted);
}

/// Heuristic deciding whether folding a binary op on two constant operands is
/// worth the memory cost of materializing a new constant tensor.
/// \returns true when folding will not increase memory usage noticeably.
bool mlir::tosa::constantBinaryOpShouldBeFolded(
    TosaOp binaryOp, DenseElementsAttr valuesFirst,
    DenseElementsAttr valuesSecond) {
  assert(binaryOp->getNumOperands() == 2);
  auto firstOp = binaryOp->getOperand(0);
  auto secondOp = binaryOp->getOperand(1);

  // If both tensors are splat, we don't care for the number of users
  if (isa<SplatElementsAttr>(valuesFirst) &&
      isa<SplatElementsAttr>(valuesSecond)) {
    return true;
  }

  // If this is the only use of one of the tensors, it will be replaced and no
  // additional memory is required.
  if (firstOp.hasOneUse() || secondOp.hasOneUse()) {
    return true;
  }

  // Fold if both inputs are equal and this op holds their only two uses; the
  // shared constant can then be replaced by the folded result. Don't fold
  // otherwise. Only walk the use list (linear in the number of users) once we
  // know both operands are the same value.
  if (firstOp != secondOp) {
    return false;
  }
  auto numUsers =
      std::distance(firstOp.getUses().begin(), firstOp.getUses().end());
  return numUsers == 2;
}

APFloat mlir::tosa::computeReciprocal(const APFloat &floatVal,
FloatType floatTy) {
auto recipAttr = FloatAttr::get(floatTy, 1.0);
Expand Down
98 changes: 98 additions & 0 deletions mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantAdd.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
//===- TosaFoldConstantAdd.cpp --------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Fold TOSA Add operation on constant data
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Transforms/TosaFoldCommon.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include <mlir/Support/LogicalResult.h>

using namespace mlir;
using namespace mlir::tosa;

namespace {

/// Rewrite pattern that replaces a tosa.add whose two inputs are dense
/// constants with a single tosa.const holding the elementwise sum.
struct TosaFoldConstantAdd : public OpRewritePattern<AddOp> {

  using OpRewritePattern::OpRewritePattern;

  /// Integer addition with signed saturation on overflow (APInt::sadd_sat).
  static APInt computeIntAdd(const APInt &first, const APInt &second) {
    return first.sadd_sat(second);
  }

  /// Floating-point addition with APFloat's default (IEEE) semantics.
  static APFloat computeFloatAdd(const APFloat &first, const APFloat &second) {
    return first + second;
  }

  LogicalResult matchAndRewrite(AddOp addOp,
                                PatternRewriter &rewriter) const override {
    auto leftOp = addOp.getInput1();
    auto rightOp = addOp.getInput2();

    auto resultType = addOp.getType();
    auto lhsElemType = leftOp.getType().getElementType();
    auto rhsElemType = rightOp.getType().getElementType();
    if (lhsElemType != rhsElemType) {
      return rewriter.notifyMatchFailure(
          addOp, "Expected type of add arguments to match.");
    }

    // Check if both tensors are constant
    auto lhsIsConstantCheck =
        notifyIfNoTosaDenseConstantTensor(leftOp, addOp, rewriter);
    if (failed(lhsIsConstantCheck)) {
      return lhsIsConstantCheck;
    }
    auto rhsIsConstantCheck =
        notifyIfNoTosaDenseConstantTensor(rightOp, addOp, rewriter);
    if (failed(rhsIsConstantCheck)) {
      return rhsIsConstantCheck;
    }

    // Extract the tensor values
    DenseElementsAttr lhsValues;
    matchPattern(leftOp, m_Constant(&lhsValues));

    DenseElementsAttr rhsValues;
    matchPattern(rightOp, m_Constant(&rhsValues));

    // Only fold when the heuristic says the memory impact is negligible.
    if (!constantBinaryOpShouldBeFolded(addOp, lhsValues, rhsValues)) {
      return rewriter.notifyMatchFailure(
          addOp, "Currently, adds will only be folded if this requires only "
                 "little additional memory usage.");
    }

    // Dispatch on the element type to pick the matching scalar addition.
    DenseElementsAttr newTensor;
    if (isa<IntegerType>(lhsElemType)) {
      assert(isa<IntegerType>(rhsElemType) &&
             isa<IntegerType>(resultType.getElementType()));
      newTensor = applyElementWise<APInt, APInt>(lhsValues, rhsValues,
                                                 resultType, &computeIntAdd);
    } else {
      assert(isa<FloatType>(lhsElemType) && isa<FloatType>(rhsElemType) &&
             isa<FloatType>(resultType.getElementType()));
      newTensor = applyElementWise<APFloat, APFloat>(
          lhsValues, rhsValues, resultType, &computeFloatAdd);
    }
    rewriter.replaceOpWithNewOp<ConstOp>(addOp, newTensor.getType(), newTensor);

    return success();
  }
};

} // namespace

/// Register the tosa.add constant-folding pattern with the given pattern set.
void mlir::tosa::populateTosaFoldConstantAddPatterns(
    MLIRContext *ctx, RewritePatternSet &patterns) {
  patterns.add<TosaFoldConstantAdd>(ctx);
}
23 changes: 11 additions & 12 deletions mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantPow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,11 @@ struct TosaFoldConstantPow : public OpRewritePattern<PowOp> {
auto baseOp = powOp.getInput1();
auto expOp = powOp.getInput2();

if (baseOp.getType().getElementType() != expOp.getType().getElementType()) {
return rewriter.notifyMatchFailure(
powOp, "Expected type of pow arguments to match.");
}

// Check if both tensors are constant
auto baseIsConstCheck =
notifyIfNotConstantFloatTosaTensor(baseOp, powOp, rewriter);
Expand All @@ -82,20 +87,14 @@ struct TosaFoldConstantPow : public OpRewritePattern<PowOp> {
DenseElementsAttr expValues;
matchPattern(expOp, m_Constant(&expValues));

// If both tensors are splat, we don't care for the number of users
if (!isa<SplatElementsAttr>(baseValues) ||
!isa<SplatElementsAttr>(expValues)) {
// Make sure that at least one of the constant input tensors can be
// replaced (i.e. only has a single user)
if (!baseOp.hasOneUse() && !expOp.hasOneUse()) {
return rewriter.notifyMatchFailure(
powOp, "Currently, pows will only be folded if at least one input "
"tensor only has a single user");
}
if (!constantBinaryOpShouldBeFolded(powOp, baseValues, expValues)) {
return rewriter.notifyMatchFailure(
powOp, "Currently, pows will only be folded if this requires only "
"little additional memory usage.");
}

auto newTensor =
applyElementWise(baseValues, expValues, powOp.getType(), &computePower);
auto newTensor = applyElementWise<APFloat, APFloat>(
baseValues, expValues, powOp.getType(), &computePower);
rewriter.replaceOpWithNewOp<ConstOp>(powOp, newTensor.getType(), newTensor);

return success();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ struct TosaLayerwiseConstantFoldPass
RewritePatternSet patterns(ctx);
auto func = getOperation();

mlir::tosa::populateTosaFoldConstantAddPatterns(ctx, patterns);
mlir::tosa::populateTosaFoldConstantCastPatterns(ctx, patterns,
enableIntCastFolding);
mlir::tosa::populateTosaFoldConstantPowPatterns(ctx, patterns);
Expand Down
Loading