Skip to content

Commit aeec945

Browse files
junfengd-nvjeanPerierjoker-eph
authored
[mlir][inliner] Add doClone and canHandleMultipleBlocks callbacks to Inliner Config (llvm#131226)
Current inliner disables inlining when the caller is in a region with single block trait, while the callee function contains multiple blocks. the SingleBlock trait is used in operations such as do/while loop, for example fir.do_loop, fir.iterate_while and fir.if. Typically, calls within loops are good candidates for inlining. However, functions with multiple blocks are also common. for example, any function with "if () then return" will result in multiple blocks in MLIR. This change gives the flexibility of a customized inliner to handle such cases. doClone: clones instructions and other information from the callee function into the caller function. . canHandleMultipleBlocks: checks if functions with multiple blocks can be inlined into a region with the SingleBlock trait. The default behavior of the inliner remains unchanged. --------- Co-authored-by: jeanPerier <jean.perier.polytechnique@gmail.com> Co-authored-by: Mehdi Amini <joker.eph@gmail.com>
1 parent 70f5632 commit aeec945

File tree

9 files changed

+333
-103
lines changed

9 files changed

+333
-103
lines changed

mlir/include/mlir/Transforms/Inliner.h

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,11 @@ class InlinerConfig {
2727
public:
2828
using DefaultPipelineTy = std::function<void(OpPassManager &)>;
2929
using OpPipelinesTy = llvm::StringMap<OpPassManager>;
30+
using CloneCallbackSigTy = void(OpBuilder &builder, Region *src,
31+
Block *inlineBlock, Block *postInsertBlock,
32+
IRMapping &mapper,
33+
bool shouldCloneInlinedRegion);
34+
using CloneCallbackTy = std::function<CloneCallbackSigTy>;
3035

3136
InlinerConfig() = default;
3237
InlinerConfig(DefaultPipelineTy defaultPipeline,
@@ -39,13 +44,22 @@ class InlinerConfig {
3944
}
4045
const OpPipelinesTy &getOpPipelines() const { return opPipelines; }
4146
unsigned getMaxInliningIterations() const { return maxInliningIterations; }
47+
const CloneCallbackTy &getCloneCallback() const { return cloneCallback; }
48+
bool getCanHandleMultipleBlocks() const { return canHandleMultipleBlocks; }
49+
4250
void setDefaultPipeline(DefaultPipelineTy pipeline) {
4351
defaultPipeline = std::move(pipeline);
4452
}
4553
void setOpPipelines(OpPipelinesTy pipelines) {
4654
opPipelines = std::move(pipelines);
4755
}
4856
void setMaxInliningIterations(unsigned max) { maxInliningIterations = max; }
57+
void setCloneCallback(CloneCallbackTy callback) {
58+
cloneCallback = std::move(callback);
59+
}
60+
void setCanHandleMultipleBlocks(bool value = true) {
61+
canHandleMultipleBlocks = value;
62+
}
4963

5064
private:
5165
/// An optional function that constructs an optimization pipeline for
@@ -60,6 +74,28 @@ class InlinerConfig {
6074
/// For SCC-based inlining algorithms, specifies maximum number of iterations
6175
/// when inlining within an SCC.
6276
unsigned maxInliningIterations{0};
77+
/// Callback for cloning operations during inlining
78+
CloneCallbackTy cloneCallback = [](OpBuilder &builder, Region *src,
79+
Block *inlineBlock, Block *postInsertBlock,
80+
IRMapping &mapper,
81+
bool shouldCloneInlinedRegion) {
82+
// Check to see if the region is being cloned, or moved inline. In
83+
// either case, move the new blocks after the 'insertBlock' to improve
84+
// IR readability.
85+
Region *insertRegion = inlineBlock->getParent();
86+
if (shouldCloneInlinedRegion)
87+
src->cloneInto(insertRegion, postInsertBlock->getIterator(), mapper);
88+
else
89+
insertRegion->getBlocks().splice(postInsertBlock->getIterator(),
90+
src->getBlocks(), src->begin(),
91+
src->end());
92+
};
93+
/// Determine if the inliner can inline a function containing multiple
94+
/// blocks into a region that requires a single block. By default, it is
95+
/// not allowed. If it is true, cloneCallback should perform the extra
96+
/// transformation. see the example in
97+
/// mlir/test/lib/Transforms/TestInliningCallback.cpp
98+
bool canHandleMultipleBlocks{false};
6399
};
64100

65101
/// This is an implementation of the inliner

mlir/include/mlir/Transforms/InliningUtils.h

Lines changed: 35 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "mlir/IR/Location.h"
1919
#include "mlir/IR/Region.h"
2020
#include "mlir/IR/ValueRange.h"
21+
#include "mlir/Transforms/Inliner.h"
2122
#include <optional>
2223

2324
namespace mlir {
@@ -253,43 +254,51 @@ class InlinerInterface
253254
/// provided, will be used to update the inlined operations' location
254255
/// information. 'shouldCloneInlinedRegion' corresponds to whether the source
255256
/// region should be cloned into the 'inlinePoint' or spliced directly.
256-
LogicalResult inlineRegion(InlinerInterface &interface, Region *src,
257-
Operation *inlinePoint, IRMapping &mapper,
258-
ValueRange resultsToReplace,
259-
TypeRange regionResultTypes,
260-
std::optional<Location> inlineLoc = std::nullopt,
261-
bool shouldCloneInlinedRegion = true);
262-
LogicalResult inlineRegion(InlinerInterface &interface, Region *src,
263-
Block *inlineBlock, Block::iterator inlinePoint,
264-
IRMapping &mapper, ValueRange resultsToReplace,
265-
TypeRange regionResultTypes,
266-
std::optional<Location> inlineLoc = std::nullopt,
267-
bool shouldCloneInlinedRegion = true);
257+
LogicalResult
258+
inlineRegion(InlinerInterface &interface,
259+
function_ref<InlinerConfig::CloneCallbackSigTy> cloneCallback,
260+
Region *src, Operation *inlinePoint, IRMapping &mapper,
261+
ValueRange resultsToReplace, TypeRange regionResultTypes,
262+
std::optional<Location> inlineLoc = std::nullopt,
263+
bool shouldCloneInlinedRegion = true);
264+
LogicalResult
265+
inlineRegion(InlinerInterface &interface,
266+
function_ref<InlinerConfig::CloneCallbackSigTy> cloneCallback,
267+
Region *src, Block *inlineBlock, Block::iterator inlinePoint,
268+
IRMapping &mapper, ValueRange resultsToReplace,
269+
TypeRange regionResultTypes,
270+
std::optional<Location> inlineLoc = std::nullopt,
271+
bool shouldCloneInlinedRegion = true);
268272

269273
/// This function is an overload of the above 'inlineRegion' that allows for
270274
/// providing the set of operands ('inlinedOperands') that should be used
271275
/// in-favor of the region arguments when inlining.
272-
LogicalResult inlineRegion(InlinerInterface &interface, Region *src,
273-
Operation *inlinePoint, ValueRange inlinedOperands,
274-
ValueRange resultsToReplace,
275-
std::optional<Location> inlineLoc = std::nullopt,
276-
bool shouldCloneInlinedRegion = true);
277-
LogicalResult inlineRegion(InlinerInterface &interface, Region *src,
278-
Block *inlineBlock, Block::iterator inlinePoint,
279-
ValueRange inlinedOperands,
280-
ValueRange resultsToReplace,
281-
std::optional<Location> inlineLoc = std::nullopt,
282-
bool shouldCloneInlinedRegion = true);
276+
LogicalResult
277+
inlineRegion(InlinerInterface &interface,
278+
function_ref<InlinerConfig::CloneCallbackSigTy> cloneCallback,
279+
Region *src, Operation *inlinePoint, ValueRange inlinedOperands,
280+
ValueRange resultsToReplace,
281+
std::optional<Location> inlineLoc = std::nullopt,
282+
bool shouldCloneInlinedRegion = true);
283+
LogicalResult
284+
inlineRegion(InlinerInterface &interface,
285+
function_ref<InlinerConfig::CloneCallbackSigTy> cloneCallback,
286+
Region *src, Block *inlineBlock, Block::iterator inlinePoint,
287+
ValueRange inlinedOperands, ValueRange resultsToReplace,
288+
std::optional<Location> inlineLoc = std::nullopt,
289+
bool shouldCloneInlinedRegion = true);
283290

284291
/// This function inlines a given region, 'src', of a callable operation,
285292
/// 'callable', into the location defined by the given call operation. This
286293
/// function returns failure if inlining is not possible, success otherwise. On
287294
/// failure, no changes are made to the module. 'shouldCloneInlinedRegion'
288295
/// corresponds to whether the source region should be cloned into the 'call' or
289296
/// spliced directly.
290-
LogicalResult inlineCall(InlinerInterface &interface, CallOpInterface call,
291-
CallableOpInterface callable, Region *src,
292-
bool shouldCloneInlinedRegion = true);
297+
LogicalResult
298+
inlineCall(InlinerInterface &interface,
299+
function_ref<InlinerConfig::CloneCallbackSigTy> cloneCallback,
300+
CallOpInterface call, CallableOpInterface callable, Region *src,
301+
bool shouldCloneInlinedRegion = true);
293302

294303
} // namespace mlir
295304

mlir/lib/Transforms/Utils/Inliner.cpp

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -652,7 +652,7 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
652652
bool inlineInPlace = useList.hasOneUseAndDiscardable(it.targetNode);
653653

654654
LogicalResult inlineResult =
655-
inlineCall(inlinerIface, call,
655+
inlineCall(inlinerIface, inliner.config.getCloneCallback(), call,
656656
cast<CallableOpInterface>(targetRegion->getParentOp()),
657657
targetRegion, /*shouldCloneInlinedRegion=*/!inlineInPlace);
658658
if (failed(inlineResult)) {
@@ -730,19 +730,22 @@ bool Inliner::Impl::shouldInline(ResolvedCall &resolvedCall) {
730730

731731
// Don't allow inlining if the callee has multiple blocks (unstructured
732732
// control flow) but we cannot be sure that the caller region supports that.
733-
bool calleeHasMultipleBlocks =
734-
llvm::hasNItemsOrMore(*callableRegion, /*N=*/2);
735-
// If both parent ops have the same type, it is safe to inline. Otherwise,
736-
// decide based on whether the op has the SingleBlock trait or not.
737-
// Note: This check does currently not account for SizedRegion/MaxSizedRegion.
738-
auto callerRegionSupportsMultipleBlocks = [&]() {
739-
return callableRegion->getParentOp()->getName() ==
740-
resolvedCall.call->getParentOp()->getName() ||
741-
!resolvedCall.call->getParentOp()
742-
->mightHaveTrait<OpTrait::SingleBlock>();
743-
};
744-
if (calleeHasMultipleBlocks && !callerRegionSupportsMultipleBlocks())
745-
return false;
733+
if (!inliner.config.getCanHandleMultipleBlocks()) {
734+
bool calleeHasMultipleBlocks =
735+
llvm::hasNItemsOrMore(*callableRegion, /*N=*/2);
736+
// If both parent ops have the same type, it is safe to inline. Otherwise,
737+
// decide based on whether the op has the SingleBlock trait or not.
738+
// Note: This check does currently not account for
739+
// SizedRegion/MaxSizedRegion.
740+
auto callerRegionSupportsMultipleBlocks = [&]() {
741+
return callableRegion->getParentOp()->getName() ==
742+
resolvedCall.call->getParentOp()->getName() ||
743+
!resolvedCall.call->getParentOp()
744+
->mightHaveTrait<OpTrait::SingleBlock>();
745+
};
746+
if (calleeHasMultipleBlocks && !callerRegionSupportsMultipleBlocks())
747+
return false;
748+
}
746749

747750
if (!inliner.isProfitableToInline(resolvedCall))
748751
return false;

mlir/lib/Transforms/Utils/InliningUtils.cpp

Lines changed: 58 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
//===----------------------------------------------------------------------===//
1212

1313
#include "mlir/Transforms/InliningUtils.h"
14+
#include "mlir/Transforms/Inliner.h"
1415

1516
#include "mlir/IR/Builders.h"
1617
#include "mlir/IR/IRMapping.h"
@@ -266,10 +267,11 @@ static void handleResultImpl(InlinerInterface &interface, OpBuilder &builder,
266267
}
267268

268269
static LogicalResult
269-
inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
270-
Block::iterator inlinePoint, IRMapping &mapper,
271-
ValueRange resultsToReplace, TypeRange regionResultTypes,
272-
std::optional<Location> inlineLoc,
270+
inlineRegionImpl(InlinerInterface &interface,
271+
function_ref<InlinerConfig::CloneCallbackSigTy> cloneCallback,
272+
Region *src, Block *inlineBlock, Block::iterator inlinePoint,
273+
IRMapping &mapper, ValueRange resultsToReplace,
274+
TypeRange regionResultTypes, std::optional<Location> inlineLoc,
273275
bool shouldCloneInlinedRegion, CallOpInterface call = {}) {
274276
assert(resultsToReplace.size() == regionResultTypes.size());
275277
// We expect the region to have at least one block.
@@ -296,16 +298,10 @@ inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
296298
if (call && callable)
297299
handleArgumentImpl(interface, builder, call, callable, mapper);
298300

299-
// Check to see if the region is being cloned, or moved inline. In either
300-
// case, move the new blocks after the 'insertBlock' to improve IR
301-
// readability.
301+
// Clone the callee's source into the caller.
302302
Block *postInsertBlock = inlineBlock->splitBlock(inlinePoint);
303-
if (shouldCloneInlinedRegion)
304-
src->cloneInto(insertRegion, postInsertBlock->getIterator(), mapper);
305-
else
306-
insertRegion->getBlocks().splice(postInsertBlock->getIterator(),
307-
src->getBlocks(), src->begin(),
308-
src->end());
303+
cloneCallback(builder, src, inlineBlock, postInsertBlock, mapper,
304+
shouldCloneInlinedRegion);
309305

310306
// Get the range of newly inserted blocks.
311307
auto newBlocks = llvm::make_range(std::next(inlineBlock->getIterator()),
@@ -374,9 +370,11 @@ inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
374370
}
375371

376372
static LogicalResult
377-
inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
378-
Block::iterator inlinePoint, ValueRange inlinedOperands,
379-
ValueRange resultsToReplace, std::optional<Location> inlineLoc,
373+
inlineRegionImpl(InlinerInterface &interface,
374+
function_ref<InlinerConfig::CloneCallbackSigTy> cloneCallback,
375+
Region *src, Block *inlineBlock, Block::iterator inlinePoint,
376+
ValueRange inlinedOperands, ValueRange resultsToReplace,
377+
std::optional<Location> inlineLoc,
380378
bool shouldCloneInlinedRegion, CallOpInterface call = {}) {
381379
// We expect the region to have at least one block.
382380
if (src->empty())
@@ -398,53 +396,54 @@ inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
398396
}
399397

400398
// Call into the main region inliner function.
401-
return inlineRegionImpl(interface, src, inlineBlock, inlinePoint, mapper,
402-
resultsToReplace, resultsToReplace.getTypes(),
403-
inlineLoc, shouldCloneInlinedRegion, call);
399+
return inlineRegionImpl(interface, cloneCallback, src, inlineBlock,
400+
inlinePoint, mapper, resultsToReplace,
401+
resultsToReplace.getTypes(), inlineLoc,
402+
shouldCloneInlinedRegion, call);
404403
}
405404

406-
LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
407-
Operation *inlinePoint, IRMapping &mapper,
408-
ValueRange resultsToReplace,
409-
TypeRange regionResultTypes,
410-
std::optional<Location> inlineLoc,
411-
bool shouldCloneInlinedRegion) {
412-
return inlineRegion(interface, src, inlinePoint->getBlock(),
405+
LogicalResult mlir::inlineRegion(
406+
InlinerInterface &interface,
407+
function_ref<InlinerConfig::CloneCallbackSigTy> cloneCallback, Region *src,
408+
Operation *inlinePoint, IRMapping &mapper, ValueRange resultsToReplace,
409+
TypeRange regionResultTypes, std::optional<Location> inlineLoc,
410+
bool shouldCloneInlinedRegion) {
411+
return inlineRegion(interface, cloneCallback, src, inlinePoint->getBlock(),
413412
++inlinePoint->getIterator(), mapper, resultsToReplace,
414413
regionResultTypes, inlineLoc, shouldCloneInlinedRegion);
415414
}
416-
LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
417-
Block *inlineBlock,
418-
Block::iterator inlinePoint, IRMapping &mapper,
419-
ValueRange resultsToReplace,
420-
TypeRange regionResultTypes,
421-
std::optional<Location> inlineLoc,
422-
bool shouldCloneInlinedRegion) {
423-
return inlineRegionImpl(interface, src, inlineBlock, inlinePoint, mapper,
424-
resultsToReplace, regionResultTypes, inlineLoc,
425-
shouldCloneInlinedRegion);
415+
416+
LogicalResult mlir::inlineRegion(
417+
InlinerInterface &interface,
418+
function_ref<InlinerConfig::CloneCallbackSigTy> cloneCallback, Region *src,
419+
Block *inlineBlock, Block::iterator inlinePoint, IRMapping &mapper,
420+
ValueRange resultsToReplace, TypeRange regionResultTypes,
421+
std::optional<Location> inlineLoc, bool shouldCloneInlinedRegion) {
422+
return inlineRegionImpl(
423+
interface, cloneCallback, src, inlineBlock, inlinePoint, mapper,
424+
resultsToReplace, regionResultTypes, inlineLoc, shouldCloneInlinedRegion);
426425
}
427426

428-
LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
429-
Operation *inlinePoint,
430-
ValueRange inlinedOperands,
431-
ValueRange resultsToReplace,
432-
std::optional<Location> inlineLoc,
433-
bool shouldCloneInlinedRegion) {
434-
return inlineRegion(interface, src, inlinePoint->getBlock(),
427+
LogicalResult mlir::inlineRegion(
428+
InlinerInterface &interface,
429+
function_ref<InlinerConfig::CloneCallbackSigTy> cloneCallback, Region *src,
430+
Operation *inlinePoint, ValueRange inlinedOperands,
431+
ValueRange resultsToReplace, std::optional<Location> inlineLoc,
432+
bool shouldCloneInlinedRegion) {
433+
return inlineRegion(interface, cloneCallback, src, inlinePoint->getBlock(),
435434
++inlinePoint->getIterator(), inlinedOperands,
436435
resultsToReplace, inlineLoc, shouldCloneInlinedRegion);
437436
}
438-
LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
439-
Block *inlineBlock,
440-
Block::iterator inlinePoint,
441-
ValueRange inlinedOperands,
442-
ValueRange resultsToReplace,
443-
std::optional<Location> inlineLoc,
444-
bool shouldCloneInlinedRegion) {
445-
return inlineRegionImpl(interface, src, inlineBlock, inlinePoint,
446-
inlinedOperands, resultsToReplace, inlineLoc,
447-
shouldCloneInlinedRegion);
437+
438+
LogicalResult mlir::inlineRegion(
439+
InlinerInterface &interface,
440+
function_ref<InlinerConfig::CloneCallbackSigTy> cloneCallback, Region *src,
441+
Block *inlineBlock, Block::iterator inlinePoint, ValueRange inlinedOperands,
442+
ValueRange resultsToReplace, std::optional<Location> inlineLoc,
443+
bool shouldCloneInlinedRegion) {
444+
return inlineRegionImpl(interface, cloneCallback, src, inlineBlock,
445+
inlinePoint, inlinedOperands, resultsToReplace,
446+
inlineLoc, shouldCloneInlinedRegion);
448447
}
449448

450449
/// Utility function used to generate a cast operation from the given interface,
@@ -475,10 +474,11 @@ static Value materializeConversion(const DialectInlinerInterface *interface,
475474
/// failure, no changes are made to the module. 'shouldCloneInlinedRegion'
476475
/// corresponds to whether the source region should be cloned into the 'call' or
477476
/// spliced directly.
478-
LogicalResult mlir::inlineCall(InlinerInterface &interface,
479-
CallOpInterface call,
480-
CallableOpInterface callable, Region *src,
481-
bool shouldCloneInlinedRegion) {
477+
LogicalResult
478+
mlir::inlineCall(InlinerInterface &interface,
479+
function_ref<InlinerConfig::CloneCallbackSigTy> cloneCallback,
480+
CallOpInterface call, CallableOpInterface callable,
481+
Region *src, bool shouldCloneInlinedRegion) {
482482
// We expect the region to have at least one block.
483483
if (src->empty())
484484
return failure();
@@ -552,7 +552,7 @@ LogicalResult mlir::inlineCall(InlinerInterface &interface,
552552
return cleanupState();
553553

554554
// Attempt to inline the call.
555-
if (failed(inlineRegionImpl(interface, src, call->getBlock(),
555+
if (failed(inlineRegionImpl(interface, cloneCallback, src, call->getBlock(),
556556
++call->getIterator(), mapper, callResults,
557557
callableResultTypes, call.getLoc(),
558558
shouldCloneInlinedRegion, call)))

0 commit comments

Comments
 (0)