Skip to content

Commit 3e08dcd

Browse files
committed
[mlir][inliner] Move callback types from InlinerConfig -> InlinerInterface. NFC.
The proper layering here is that Inliner depends on InlinerUtils, and not the other way round. Maybe it's time to give InliningUtils a less terrible file name.
1 parent 0a17427 commit 3e08dcd

File tree

4 files changed

+64
-63
lines changed

4 files changed

+64
-63
lines changed

mlir/include/mlir/Transforms/Inliner.h

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "mlir/Interfaces/CallInterfaces.h"
1818
#include "mlir/Pass/AnalysisManager.h"
1919
#include "mlir/Pass/PassManager.h"
20+
#include "mlir/Transforms/InliningUtils.h"
2021
#include "llvm/ADT/StringMap.h"
2122

2223
namespace mlir {
@@ -27,11 +28,6 @@ class InlinerConfig {
2728
public:
2829
using DefaultPipelineTy = std::function<void(OpPassManager &)>;
2930
using OpPipelinesTy = llvm::StringMap<OpPassManager>;
30-
using CloneCallbackSigTy = void(OpBuilder &builder, Region *src,
31-
Block *inlineBlock, Block *postInsertBlock,
32-
IRMapping &mapper,
33-
bool shouldCloneInlinedRegion);
34-
using CloneCallbackTy = std::function<CloneCallbackSigTy>;
3531

3632
InlinerConfig() = default;
3733
InlinerConfig(DefaultPipelineTy defaultPipeline,
@@ -44,7 +40,9 @@ class InlinerConfig {
4440
}
4541
const OpPipelinesTy &getOpPipelines() const { return opPipelines; }
4642
unsigned getMaxInliningIterations() const { return maxInliningIterations; }
47-
const CloneCallbackTy &getCloneCallback() const { return cloneCallback; }
43+
const InlinerInterface::CloneCallbackTy &getCloneCallback() const {
44+
return cloneCallback;
45+
}
4846
bool getCanHandleMultipleBlocks() const { return canHandleMultipleBlocks; }
4947

5048
void setDefaultPipeline(DefaultPipelineTy pipeline) {
@@ -54,7 +52,7 @@ class InlinerConfig {
5452
opPipelines = std::move(pipelines);
5553
}
5654
void setMaxInliningIterations(unsigned max) { maxInliningIterations = max; }
57-
void setCloneCallback(CloneCallbackTy callback) {
55+
void setCloneCallback(InlinerInterface::CloneCallbackTy callback) {
5856
cloneCallback = std::move(callback);
5957
}
6058
void setCanHandleMultipleBlocks(bool value = true) {
@@ -75,21 +73,21 @@ class InlinerConfig {
7573
/// when inlining within an SCC.
7674
unsigned maxInliningIterations{0};
7775
/// Callback for cloning operations during inlining
78-
CloneCallbackTy cloneCallback = [](OpBuilder &builder, Region *src,
79-
Block *inlineBlock, Block *postInsertBlock,
80-
IRMapping &mapper,
81-
bool shouldCloneInlinedRegion) {
82-
// Check to see if the region is being cloned, or moved inline. In
83-
// either case, move the new blocks after the 'insertBlock' to improve
84-
// IR readability.
85-
Region *insertRegion = inlineBlock->getParent();
86-
if (shouldCloneInlinedRegion)
87-
src->cloneInto(insertRegion, postInsertBlock->getIterator(), mapper);
88-
else
89-
insertRegion->getBlocks().splice(postInsertBlock->getIterator(),
90-
src->getBlocks(), src->begin(),
91-
src->end());
92-
};
76+
InlinerInterface::CloneCallbackTy cloneCallback =
77+
[](OpBuilder &builder, Region *src, Block *inlineBlock,
78+
Block *postInsertBlock, IRMapping &mapper,
79+
bool shouldCloneInlinedRegion) {
80+
// Check to see if the region is being cloned, or moved inline. In
81+
// either case, move the new blocks after the 'insertBlock' to improve
82+
// IR readability.
83+
Region *insertRegion = inlineBlock->getParent();
84+
if (shouldCloneInlinedRegion)
85+
src->cloneInto(insertRegion, postInsertBlock->getIterator(), mapper);
86+
else
87+
insertRegion->getBlocks().splice(postInsertBlock->getIterator(),
88+
src->getBlocks(), src->begin(),
89+
src->end());
90+
};
9391
/// Determine if the inliner can inline a function containing multiple
9492
/// blocks into a region that requires a single block. By default, it is
9593
/// not allowed. If it is true, cloneCallback should perform the extra

mlir/include/mlir/Transforms/InliningUtils.h

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
#include "mlir/IR/Location.h"
1919
#include "mlir/IR/Region.h"
2020
#include "mlir/IR/ValueRange.h"
21-
#include "mlir/Transforms/Inliner.h"
2221
#include <optional>
2322

2423
namespace mlir {
@@ -192,6 +191,12 @@ class DialectInlinerInterface
192191
class InlinerInterface
193192
: public DialectInterfaceCollection<DialectInlinerInterface> {
194193
public:
194+
using CloneCallbackSigTy = void(OpBuilder &builder, Region *src,
195+
Block *inlineBlock, Block *postInsertBlock,
196+
IRMapping &mapper,
197+
bool shouldCloneInlinedRegion);
198+
using CloneCallbackTy = std::function<CloneCallbackSigTy>;
199+
195200
using Base::Base;
196201

197202
/// Process a set of blocks that have been inlined. This callback is invoked
@@ -256,14 +261,14 @@ class InlinerInterface
256261
/// region should be cloned into the 'inlinePoint' or spliced directly.
257262
LogicalResult
258263
inlineRegion(InlinerInterface &interface,
259-
function_ref<InlinerConfig::CloneCallbackSigTy> cloneCallback,
264+
function_ref<InlinerInterface::CloneCallbackSigTy> cloneCallback,
260265
Region *src, Operation *inlinePoint, IRMapping &mapper,
261266
ValueRange resultsToReplace, TypeRange regionResultTypes,
262267
std::optional<Location> inlineLoc = std::nullopt,
263268
bool shouldCloneInlinedRegion = true);
264269
LogicalResult
265270
inlineRegion(InlinerInterface &interface,
266-
function_ref<InlinerConfig::CloneCallbackSigTy> cloneCallback,
271+
function_ref<InlinerInterface::CloneCallbackSigTy> cloneCallback,
267272
Region *src, Block *inlineBlock, Block::iterator inlinePoint,
268273
IRMapping &mapper, ValueRange resultsToReplace,
269274
TypeRange regionResultTypes,
@@ -275,14 +280,14 @@ inlineRegion(InlinerInterface &interface,
275280
/// in-favor of the region arguments when inlining.
276281
LogicalResult
277282
inlineRegion(InlinerInterface &interface,
278-
function_ref<InlinerConfig::CloneCallbackSigTy> cloneCallback,
283+
function_ref<InlinerInterface::CloneCallbackSigTy> cloneCallback,
279284
Region *src, Operation *inlinePoint, ValueRange inlinedOperands,
280285
ValueRange resultsToReplace,
281286
std::optional<Location> inlineLoc = std::nullopt,
282287
bool shouldCloneInlinedRegion = true);
283288
LogicalResult
284289
inlineRegion(InlinerInterface &interface,
285-
function_ref<InlinerConfig::CloneCallbackSigTy> cloneCallback,
290+
function_ref<InlinerInterface::CloneCallbackSigTy> cloneCallback,
286291
Region *src, Block *inlineBlock, Block::iterator inlinePoint,
287292
ValueRange inlinedOperands, ValueRange resultsToReplace,
288293
std::optional<Location> inlineLoc = std::nullopt,
@@ -296,7 +301,7 @@ inlineRegion(InlinerInterface &interface,
296301
/// spliced directly.
297302
LogicalResult
298303
inlineCall(InlinerInterface &interface,
299-
function_ref<InlinerConfig::CloneCallbackSigTy> cloneCallback,
304+
function_ref<InlinerInterface::CloneCallbackSigTy> cloneCallback,
300305
CallOpInterface call, CallableOpInterface callable, Region *src,
301306
bool shouldCloneInlinedRegion = true);
302307

mlir/lib/Transforms/Utils/InliningUtils.cpp

Lines changed: 32 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
//===----------------------------------------------------------------------===//
1212

1313
#include "mlir/Transforms/InliningUtils.h"
14-
#include "mlir/Transforms/Inliner.h"
1514

1615
#include "mlir/IR/Builders.h"
1716
#include "mlir/IR/IRMapping.h"
@@ -266,13 +265,13 @@ static void handleResultImpl(InlinerInterface &interface, OpBuilder &builder,
266265
}
267266
}
268267

269-
static LogicalResult
270-
inlineRegionImpl(InlinerInterface &interface,
271-
function_ref<InlinerConfig::CloneCallbackSigTy> cloneCallback,
272-
Region *src, Block *inlineBlock, Block::iterator inlinePoint,
273-
IRMapping &mapper, ValueRange resultsToReplace,
274-
TypeRange regionResultTypes, std::optional<Location> inlineLoc,
275-
bool shouldCloneInlinedRegion, CallOpInterface call = {}) {
268+
static LogicalResult inlineRegionImpl(
269+
InlinerInterface &interface,
270+
function_ref<InlinerInterface::CloneCallbackSigTy> cloneCallback,
271+
Region *src, Block *inlineBlock, Block::iterator inlinePoint,
272+
IRMapping &mapper, ValueRange resultsToReplace, TypeRange regionResultTypes,
273+
std::optional<Location> inlineLoc, bool shouldCloneInlinedRegion,
274+
CallOpInterface call = {}) {
276275
assert(resultsToReplace.size() == regionResultTypes.size());
277276
// We expect the region to have at least one block.
278277
if (src->empty())
@@ -369,13 +368,13 @@ inlineRegionImpl(InlinerInterface &interface,
369368
return success();
370369
}
371370

372-
static LogicalResult
373-
inlineRegionImpl(InlinerInterface &interface,
374-
function_ref<InlinerConfig::CloneCallbackSigTy> cloneCallback,
375-
Region *src, Block *inlineBlock, Block::iterator inlinePoint,
376-
ValueRange inlinedOperands, ValueRange resultsToReplace,
377-
std::optional<Location> inlineLoc,
378-
bool shouldCloneInlinedRegion, CallOpInterface call = {}) {
371+
static LogicalResult inlineRegionImpl(
372+
InlinerInterface &interface,
373+
function_ref<InlinerInterface::CloneCallbackSigTy> cloneCallback,
374+
Region *src, Block *inlineBlock, Block::iterator inlinePoint,
375+
ValueRange inlinedOperands, ValueRange resultsToReplace,
376+
std::optional<Location> inlineLoc, bool shouldCloneInlinedRegion,
377+
CallOpInterface call = {}) {
379378
// We expect the region to have at least one block.
380379
if (src->empty())
381380
return failure();
@@ -404,20 +403,20 @@ inlineRegionImpl(InlinerInterface &interface,
404403

405404
LogicalResult mlir::inlineRegion(
406405
InlinerInterface &interface,
407-
function_ref<InlinerConfig::CloneCallbackSigTy> cloneCallback, Region *src,
408-
Operation *inlinePoint, IRMapping &mapper, ValueRange resultsToReplace,
409-
TypeRange regionResultTypes, std::optional<Location> inlineLoc,
410-
bool shouldCloneInlinedRegion) {
406+
function_ref<InlinerInterface::CloneCallbackSigTy> cloneCallback,
407+
Region *src, Operation *inlinePoint, IRMapping &mapper,
408+
ValueRange resultsToReplace, TypeRange regionResultTypes,
409+
std::optional<Location> inlineLoc, bool shouldCloneInlinedRegion) {
411410
return inlineRegion(interface, cloneCallback, src, inlinePoint->getBlock(),
412411
++inlinePoint->getIterator(), mapper, resultsToReplace,
413412
regionResultTypes, inlineLoc, shouldCloneInlinedRegion);
414413
}
415414

416415
LogicalResult mlir::inlineRegion(
417416
InlinerInterface &interface,
418-
function_ref<InlinerConfig::CloneCallbackSigTy> cloneCallback, Region *src,
419-
Block *inlineBlock, Block::iterator inlinePoint, IRMapping &mapper,
420-
ValueRange resultsToReplace, TypeRange regionResultTypes,
417+
function_ref<InlinerInterface::CloneCallbackSigTy> cloneCallback,
418+
Region *src, Block *inlineBlock, Block::iterator inlinePoint,
419+
IRMapping &mapper, ValueRange resultsToReplace, TypeRange regionResultTypes,
421420
std::optional<Location> inlineLoc, bool shouldCloneInlinedRegion) {
422421
return inlineRegionImpl(
423422
interface, cloneCallback, src, inlineBlock, inlinePoint, mapper,
@@ -426,8 +425,8 @@ LogicalResult mlir::inlineRegion(
426425

427426
LogicalResult mlir::inlineRegion(
428427
InlinerInterface &interface,
429-
function_ref<InlinerConfig::CloneCallbackSigTy> cloneCallback, Region *src,
430-
Operation *inlinePoint, ValueRange inlinedOperands,
428+
function_ref<InlinerInterface::CloneCallbackSigTy> cloneCallback,
429+
Region *src, Operation *inlinePoint, ValueRange inlinedOperands,
431430
ValueRange resultsToReplace, std::optional<Location> inlineLoc,
432431
bool shouldCloneInlinedRegion) {
433432
return inlineRegion(interface, cloneCallback, src, inlinePoint->getBlock(),
@@ -437,10 +436,10 @@ LogicalResult mlir::inlineRegion(
437436

438437
LogicalResult mlir::inlineRegion(
439438
InlinerInterface &interface,
440-
function_ref<InlinerConfig::CloneCallbackSigTy> cloneCallback, Region *src,
441-
Block *inlineBlock, Block::iterator inlinePoint, ValueRange inlinedOperands,
442-
ValueRange resultsToReplace, std::optional<Location> inlineLoc,
443-
bool shouldCloneInlinedRegion) {
439+
function_ref<InlinerInterface::CloneCallbackSigTy> cloneCallback,
440+
Region *src, Block *inlineBlock, Block::iterator inlinePoint,
441+
ValueRange inlinedOperands, ValueRange resultsToReplace,
442+
std::optional<Location> inlineLoc, bool shouldCloneInlinedRegion) {
444443
return inlineRegionImpl(interface, cloneCallback, src, inlineBlock,
445444
inlinePoint, inlinedOperands, resultsToReplace,
446445
inlineLoc, shouldCloneInlinedRegion);
@@ -474,11 +473,11 @@ static Value materializeConversion(const DialectInlinerInterface *interface,
474473
/// failure, no changes are made to the module. 'shouldCloneInlinedRegion'
475474
/// corresponds to whether the source region should be cloned into the 'call' or
476475
/// spliced directly.
477-
LogicalResult
478-
mlir::inlineCall(InlinerInterface &interface,
479-
function_ref<InlinerConfig::CloneCallbackSigTy> cloneCallback,
480-
CallOpInterface call, CallableOpInterface callable,
481-
Region *src, bool shouldCloneInlinedRegion) {
476+
LogicalResult mlir::inlineCall(
477+
InlinerInterface &interface,
478+
function_ref<InlinerInterface::CloneCallbackSigTy> cloneCallback,
479+
CallOpInterface call, CallableOpInterface callable, Region *src,
480+
bool shouldCloneInlinedRegion) {
482481
// We expect the region to have at least one block.
483482
if (src->empty())
484483
return failure();

mlir/test/lib/Transforms/TestInliningCallback.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,7 @@ struct InlinerCallback
6161
src->cloneInto(&region, mapper);
6262

6363
// Split block before scf operation.
64-
Block *continueBlock =
65-
inlineBlock->splitBlock(executeRegionOp.getOperation());
64+
inlineBlock->splitBlock(executeRegionOp.getOperation());
6665

6766
// Replace all test.return with scf.yield
6867
for (mlir::Block &block : region) {

0 commit comments

Comments
 (0)