175 changes: 155 additions & 20 deletions mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -60,9 +60,51 @@ struct ConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
  }
};

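// Rewrites a concat whose operands are all the same value into a tile along
// the concat axis. A sketch of the rewrite in TOSA assembly (attribute
// syntax varies across dialect versions; %a and the 2x3 shapes are
// illustrative):
//
//   %0 = tosa.concat %a, %a, %a {axis = 1 : i32}
//       : (tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x9xf32>
// becomes
//   %0 = tosa.tile %a {multiples = array<i64: 1, 3>}
//       : (tensor<2x3xf32>) -> tensor<2x9xf32>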
struct SelfConcatToTile : public OpRewritePattern<tosa::ConcatOp> {
  using OpRewritePattern<tosa::ConcatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    // Guard against an empty user range: all_equal is true for it, but
    // dereferencing the first use below would then be invalid.
    if (!concatOp->use_empty() && llvm::all_equal(concatOp->getUsers())) {
      const auto concatUser = llvm::dyn_cast<tosa::ConcatOp>(
          concatOp->getUses().begin()->getOwner());
      if (concatUser) {
        // Try folding the concat into its consumer before rewriting it to a
        // tile.
        SmallVector<Value> replacementValues;
        auto foldResult = rewriter.tryFold(concatUser, replacementValues);
        if (foldResult.succeeded()) {
          if (!replacementValues.empty()) {
            rewriter.replaceOp(concatUser, replacementValues);
          }
          return success();
        }
      }
    }

    if (!llvm::all_equal(concatOp->getOperands())) {
      return rewriter.notifyMatchFailure(
          concatOp, "requires all operands to be the same");
    }
    const auto concatType = dyn_cast<ShapedType>(concatOp.getType());
    if (!concatType || !concatType.hasRank()) {
      return rewriter.notifyMatchFailure(concatOp,
                                         "requires concat to be ranked");
    }
    SmallVector<int64_t> multiples(concatType.getRank(), 1);
    multiples[concatOp.getAxis()] = concatOp->getNumOperands();
    auto tileOp = rewriter.createOrFold<tosa::TileOp>(
        concatOp->getLoc(), concatOp.getType(), concatOp->getOperand(0),
        multiples);
    rewriter.replaceOp(concatOp, {tileOp});
    return success();
  }
};

void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<ConcatOptimization>(context);
  results.add<SelfConcatToTile>(context);
}

struct SqrtReciprocalOptimization : public OpRewritePattern<tosa::PowOp> {
@@ -611,42 +653,120 @@ struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {

    llvm::SmallVector<int64_t> sliceStart(sliceOp.getStart());
    llvm::ArrayRef<int64_t> sliceSize = sliceOp.getSize();

    // Walk the concat inputs along the sliced axis and gather the subset
    // that the slice actually reads, so the concat can be shrunk to just
    // those inputs.
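    // For example (a sketch; %a, %b, %c and the 2x3 shapes are illustrative),
    // a slice reading rows [3, 5) of a three-way concat needs only %b and %c:
    //   slice(concat(%a, %b, %c, axis = 0), start = [3, 0], size = [2, 3])
    //     == slice(concat(%b, %c, axis = 0), start = [1, 0], size = [2, 3])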
    llvm::SmallVector<Value> requiredConcatInputs;
    int64_t processedOriginalConcatInputSize = 0;
    int64_t droppedConcatInputSize = 0;
    for (auto input : inputs) {
      const auto inputType = dyn_cast<RankedTensorType>(input.getType());
      if (!inputType || !inputType.hasStaticShape())
        return rewriter.notifyMatchFailure(
            sliceOp, "concat input must be a static ranked tensor");

      // Keep this input if its span along the axis overlaps the sliced
      // interval [sliceStart, sliceStart + sliceSize).
      if (processedOriginalConcatInputSize <
              (sliceStart[axis] + sliceSize[axis]) &&
          (processedOriginalConcatInputSize + inputType.getDimSize(axis)) >
              sliceStart[axis]) {
        if (requiredConcatInputs.empty()) {
          droppedConcatInputSize = processedOriginalConcatInputSize;
        }
        requiredConcatInputs.push_back(input);
      }
      processedOriginalConcatInputSize += inputType.getDimSize(axis);
    }
    if (requiredConcatInputs.size() == concatOp->getNumOperands()) {
      return rewriter.notifyMatchFailure(
          sliceOp, "could not reduce number of inputs to preceding concat");
    }
    // Do not introduce new concats.
    if (requiredConcatInputs.size() != 1 && !concatOp->hasOneUse()) {
      return rewriter.notifyMatchFailure(
          sliceOp, "preceding concat must have a single use");
    }
    if (requiredConcatInputs.empty()) {
      return rewriter.notifyMatchFailure(
          sliceOp, "degenerate slice with zero-sized dim in output");
    }
    // Rebase the slice start onto the reduced concat.
    sliceStart[axis] -= droppedConcatInputSize;
    auto newConcat = rewriter.create<tosa::ConcatOp>(
        concatOp->getLoc(), requiredConcatInputs, axis);
    auto newSlice = rewriter.create<tosa::SliceOp>(
        sliceOp->getLoc(), sliceOp.getType(), newConcat,
        rewriter.getDenseI64ArrayAttr(sliceStart),
        rewriter.getDenseI64ArrayAttr(sliceSize));
    rewriter.replaceOp(sliceOp, newSlice);
    return success();
  }
};

/// This pattern adjusts the multiples of a tile followed by a slice so that
/// the tile only produces as much data as the slice actually reads.
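/// For example, a sketch in TOSA assembly (attribute syntax varies across
/// dialect versions; %x and the 2x3 shapes are illustrative). Slicing 3 rows
/// starting at row 3 out of a 2x3 tensor tiled 4x only needs 2 copies:
///
///   %t = tosa.tile %x {multiples = array<i64: 4, 1>}
///       : (tensor<2x3xf32>) -> tensor<8x3xf32>
///   %s = tosa.slice %t {start = array<i64: 3, 0>, size = array<i64: 3, 3>}
///       : (tensor<8x3xf32>) -> tensor<3x3xf32>
/// becomes
///   %t = tosa.tile %x {multiples = array<i64: 2, 1>}
///       : (tensor<2x3xf32>) -> tensor<4x3xf32>
///   %s = tosa.slice %t {start = array<i64: 1, 0>, size = array<i64: 3, 3>}
///       : (tensor<4x3xf32>) -> tensor<3x3xf32>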
struct TileSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
  using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    Value sliceInput = sliceOp.getInput1();
    auto tileOp = sliceInput.getDefiningOp<tosa::TileOp>();
    if (!tileOp)
      return rewriter.notifyMatchFailure(
          sliceOp, "slice input must be a tile operation");
    // Do not insert additional tiles.
    if (!tileOp->hasOneUse())
      return rewriter.notifyMatchFailure(
          sliceOp, "preceding tile must have a single use");

    const auto tileOpInputType =
        dyn_cast<RankedTensorType>(tileOp->getOperand(0).getType());
    if (!tileOpInputType || !tileOpInputType.hasStaticShape())
      return rewriter.notifyMatchFailure(
          sliceOp, "input to preceding tile op must be a static ranked tensor");

    llvm::SmallVector<int64_t> requiredMultipliers;
    llvm::SmallVector<int64_t> newTileStarts;
    requiredMultipliers.reserve(tileOpInputType.getRank());
    newTileStarts.reserve(tileOpInputType.getRank());
    for (auto [axis, sliceStart, sliceSize] :
         llvm::enumerate(sliceOp.getStart(), sliceOp.getSize())) {
      if (sliceSize <= 0) {
        return rewriter.notifyMatchFailure(
            sliceOp, "degenerate slice with zero-sized dim");
      }
      // Rebase the slice into the first tile copy it touches, then count how
      // many further copies it spans.
      const int64_t tileInputDimSize = tileOpInputType.getDimSize(axis);
      const int64_t sliceOffsetInNewFirstTile = sliceStart % tileInputDimSize;
      const int64_t sliceSizeInFirstTile =
          std::min(tileInputDimSize - sliceOffsetInNewFirstTile, sliceSize);
      assert(sliceSizeInFirstTile > 0);
      const int64_t requiredMultiplierWithoutFirstTile =
          llvm::divideCeil(sliceSize - sliceSizeInFirstTile, tileInputDimSize);
      const int64_t requiredMultiplier =
          requiredMultiplierWithoutFirstTile + (sliceSizeInFirstTile != 0);
      assert(requiredMultiplier <= tileOp.getMultiples()[axis]);
      requiredMultipliers.push_back(requiredMultiplier);
      newTileStarts.push_back(sliceOffsetInNewFirstTile);
    }
    if (requiredMultipliers == tileOp.getMultiples())
      return rewriter.notifyMatchFailure(
          sliceOp, "could not reduce multipliers in preceding tile");

    llvm::SmallVector<int64_t> newTileShape(tileOpInputType.getShape());
    for (auto [newShape, multiplier] :
         llvm::zip_equal(newTileShape, requiredMultipliers)) {
      newShape *= multiplier;
    }
    auto newTile = rewriter.create<tosa::TileOp>(
        tileOp->getLoc(), tileOpInputType.clone(newTileShape),
        tileOp->getOperand(0), requiredMultipliers);
    auto newSlice = rewriter.create<tosa::SliceOp>(
        sliceOp->getLoc(), sliceOp.getType(), newTile,
        rewriter.getDenseI64ArrayAttr(newTileStarts), sliceOp.getSizeAttr());
    rewriter.replaceOp(sliceOp, newSlice);
    return success();
  }
};

void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<ConcatSliceOptimization>(context);
  results.add<TileSliceOptimization>(context);
}

struct MinToClampOptimization : public OpRewritePattern<tosa::MinimumOp> {
@@ -1320,6 +1440,21 @@ OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
  bool allOnes = llvm::all_of(getMultiples(), [](int64_t v) { return v == 1; });
  if (allOnes && getInput1().getType() == getType())
    return getInput1();

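  // A tile of a tile folds into a single tile whose multiples are the
  // element-wise product of both, e.g. (a sketch; %x and the shapes are
  // illustrative):
  //   tile(tile(%x : tensor<2x3xf32>, multiples = [2, 1]), multiples = [1, 3])
  //     == tile(%x, multiples = [2, 3]) : tensor<4x9xf32>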
  if (auto inputTile = getInput1().getDefiningOp<TileOp>()) {
    // Only fold when this tile is the sole user; otherwise the producer tile
    // has to stay around anyway.
    if (!inputTile->hasOneUse()) {
      return {};
    }
    llvm::SmallVector<int64_t> newMultiples{getMultiples()};
    for (auto [idx, multiplier] : llvm::enumerate(inputTile.getMultiples())) {
      newMultiples[idx] *= multiplier;
    }
    setMultiples(newMultiples);
    setOperand(inputTile->getOperand(0));
    getOperation()->setLoc(
        FusedLoc::get(getContext(), {inputTile->getLoc(), getLoc()}));
    return getResult();
  }
  return {};
}
