11
11
// ===----------------------------------------------------------------------===//
12
12
13
13
#include " mlir/Transforms/InliningUtils.h"
14
+ #include " mlir/Transforms/Inliner.h"
14
15
15
16
#include " mlir/IR/Builders.h"
16
17
#include " mlir/IR/IRMapping.h"
@@ -266,10 +267,11 @@ static void handleResultImpl(InlinerInterface &interface, OpBuilder &builder,
266
267
}
267
268
268
269
static LogicalResult
269
- inlineRegionImpl (InlinerInterface &interface, Region *src, Block *inlineBlock,
270
- Block::iterator inlinePoint, IRMapping &mapper,
271
- ValueRange resultsToReplace, TypeRange regionResultTypes,
272
- std::optional<Location> inlineLoc,
270
+ inlineRegionImpl (InlinerInterface &interface,
271
+ function_ref<InlinerConfig::CloneCallbackSigTy> cloneCallback,
272
+ Region *src, Block *inlineBlock, Block::iterator inlinePoint,
273
+ IRMapping &mapper, ValueRange resultsToReplace,
274
+ TypeRange regionResultTypes, std::optional<Location> inlineLoc,
273
275
bool shouldCloneInlinedRegion, CallOpInterface call = {}) {
274
276
assert (resultsToReplace.size () == regionResultTypes.size ());
275
277
// We expect the region to have at least one block.
@@ -296,16 +298,10 @@ inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
296
298
if (call && callable)
297
299
handleArgumentImpl (interface, builder, call, callable, mapper);
298
300
299
- // Check to see if the region is being cloned, or moved inline. In either
300
- // case, move the new blocks after the 'insertBlock' to improve IR
301
- // readability.
301
+ // Clone the callee's source into the caller.
302
302
Block *postInsertBlock = inlineBlock->splitBlock (inlinePoint);
303
- if (shouldCloneInlinedRegion)
304
- src->cloneInto (insertRegion, postInsertBlock->getIterator (), mapper);
305
- else
306
- insertRegion->getBlocks ().splice (postInsertBlock->getIterator (),
307
- src->getBlocks (), src->begin (),
308
- src->end ());
303
+ cloneCallback (builder, src, inlineBlock, postInsertBlock, mapper,
304
+ shouldCloneInlinedRegion);
309
305
310
306
// Get the range of newly inserted blocks.
311
307
auto newBlocks = llvm::make_range (std::next (inlineBlock->getIterator ()),
@@ -374,9 +370,11 @@ inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
374
370
}
375
371
376
372
static LogicalResult
377
- inlineRegionImpl (InlinerInterface &interface, Region *src, Block *inlineBlock,
378
- Block::iterator inlinePoint, ValueRange inlinedOperands,
379
- ValueRange resultsToReplace, std::optional<Location> inlineLoc,
373
+ inlineRegionImpl (InlinerInterface &interface,
374
+ function_ref<InlinerConfig::CloneCallbackSigTy> cloneCallback,
375
+ Region *src, Block *inlineBlock, Block::iterator inlinePoint,
376
+ ValueRange inlinedOperands, ValueRange resultsToReplace,
377
+ std::optional<Location> inlineLoc,
380
378
bool shouldCloneInlinedRegion, CallOpInterface call = {}) {
381
379
// We expect the region to have at least one block.
382
380
if (src->empty ())
@@ -398,53 +396,54 @@ inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
398
396
}
399
397
400
398
// Call into the main region inliner function.
401
- return inlineRegionImpl (interface, src, inlineBlock, inlinePoint, mapper,
402
- resultsToReplace, resultsToReplace.getTypes (),
403
- inlineLoc, shouldCloneInlinedRegion, call);
399
+ return inlineRegionImpl (interface, cloneCallback, src, inlineBlock,
400
+ inlinePoint, mapper, resultsToReplace,
401
+ resultsToReplace.getTypes (), inlineLoc,
402
+ shouldCloneInlinedRegion, call);
404
403
}
405
404
406
- LogicalResult mlir::inlineRegion (InlinerInterface &interface, Region *src,
407
- Operation *inlinePoint, IRMapping &mapper ,
408
- ValueRange resultsToReplace ,
409
- TypeRange regionResultTypes ,
410
- std::optional<Location> inlineLoc,
411
- bool shouldCloneInlinedRegion) {
412
- return inlineRegion (interface, src, inlinePoint->getBlock (),
405
+ LogicalResult mlir::inlineRegion (
406
+ InlinerInterface &interface ,
407
+ function_ref<InlinerConfig::CloneCallbackSigTy> cloneCallback, Region *src ,
408
+ Operation *inlinePoint, IRMapping &mapper, ValueRange resultsToReplace ,
409
+ TypeRange regionResultTypes, std::optional<Location> inlineLoc,
410
+ bool shouldCloneInlinedRegion) {
411
+ return inlineRegion (interface, cloneCallback, src, inlinePoint->getBlock (),
413
412
++inlinePoint->getIterator (), mapper, resultsToReplace,
414
413
regionResultTypes, inlineLoc, shouldCloneInlinedRegion);
415
414
}
416
- LogicalResult mlir::inlineRegion (InlinerInterface &interface, Region *src,
417
- Block *inlineBlock,
418
- Block::iterator inlinePoint, IRMapping &mapper ,
419
- ValueRange resultsToReplace ,
420
- TypeRange regionResultTypes ,
421
- std::optional<Location> inlineLoc ,
422
- bool shouldCloneInlinedRegion) {
423
- return inlineRegionImpl (interface, src, inlineBlock, inlinePoint, mapper,
424
- resultsToReplace, regionResultTypes, inlineLoc ,
425
- shouldCloneInlinedRegion);
415
+
416
+ LogicalResult mlir::inlineRegion (
417
+ InlinerInterface &interface ,
418
+ function_ref<InlinerConfig::CloneCallbackSigTy> cloneCallback, Region *src ,
419
+ Block *inlineBlock, Block::iterator inlinePoint, IRMapping &mapper ,
420
+ ValueRange resultsToReplace, TypeRange regionResultTypes ,
421
+ std::optional<Location> inlineLoc, bool shouldCloneInlinedRegion) {
422
+ return inlineRegionImpl (
423
+ interface, cloneCallback, src, inlineBlock, inlinePoint, mapper ,
424
+ resultsToReplace, regionResultTypes, inlineLoc, shouldCloneInlinedRegion);
426
425
}
427
426
428
- LogicalResult mlir::inlineRegion (InlinerInterface &interface, Region *src,
429
- Operation *inlinePoint ,
430
- ValueRange inlinedOperands ,
431
- ValueRange resultsToReplace ,
432
- std::optional<Location> inlineLoc,
433
- bool shouldCloneInlinedRegion) {
434
- return inlineRegion (interface, src, inlinePoint->getBlock (),
427
+ LogicalResult mlir::inlineRegion (
428
+ InlinerInterface &interface ,
429
+ function_ref<InlinerConfig::CloneCallbackSigTy> cloneCallback, Region *src ,
430
+ Operation *inlinePoint, ValueRange inlinedOperands ,
431
+ ValueRange resultsToReplace, std::optional<Location> inlineLoc,
432
+ bool shouldCloneInlinedRegion) {
433
+ return inlineRegion (interface, cloneCallback, src, inlinePoint->getBlock (),
435
434
++inlinePoint->getIterator (), inlinedOperands,
436
435
resultsToReplace, inlineLoc, shouldCloneInlinedRegion);
437
436
}
438
- LogicalResult mlir::inlineRegion (InlinerInterface &interface, Region *src,
439
- Block *inlineBlock,
440
- Block::iterator inlinePoint ,
441
- ValueRange inlinedOperands ,
442
- ValueRange resultsToReplace ,
443
- std::optional<Location> inlineLoc,
444
- bool shouldCloneInlinedRegion) {
445
- return inlineRegionImpl (interface, src, inlineBlock, inlinePoint ,
446
- inlinedOperands, resultsToReplace, inlineLoc ,
447
- shouldCloneInlinedRegion);
437
+
438
+ LogicalResult mlir::inlineRegion (
439
+ InlinerInterface &interface ,
440
+ function_ref<InlinerConfig::CloneCallbackSigTy> cloneCallback, Region *src ,
441
+ Block *inlineBlock, Block::iterator inlinePoint, ValueRange inlinedOperands ,
442
+ ValueRange resultsToReplace, std::optional<Location> inlineLoc,
443
+ bool shouldCloneInlinedRegion) {
444
+ return inlineRegionImpl (interface, cloneCallback, src, inlineBlock ,
445
+ inlinePoint, inlinedOperands, resultsToReplace ,
446
+ inlineLoc, shouldCloneInlinedRegion);
448
447
}
449
448
450
449
// / Utility function used to generate a cast operation from the given interface,
@@ -475,10 +474,11 @@ static Value materializeConversion(const DialectInlinerInterface *interface,
475
474
// / failure, no changes are made to the module. 'shouldCloneInlinedRegion'
476
475
// / corresponds to whether the source region should be cloned into the 'call' or
477
476
// / spliced directly.
478
- LogicalResult mlir::inlineCall (InlinerInterface &interface,
479
- CallOpInterface call,
480
- CallableOpInterface callable, Region *src,
481
- bool shouldCloneInlinedRegion) {
477
+ LogicalResult
478
+ mlir::inlineCall (InlinerInterface &interface,
479
+ function_ref<InlinerConfig::CloneCallbackSigTy> cloneCallback,
480
+ CallOpInterface call, CallableOpInterface callable,
481
+ Region *src, bool shouldCloneInlinedRegion) {
482
482
// We expect the region to have at least one block.
483
483
if (src->empty ())
484
484
return failure ();
@@ -552,7 +552,7 @@ LogicalResult mlir::inlineCall(InlinerInterface &interface,
552
552
return cleanupState ();
553
553
554
554
// Attempt to inline the call.
555
- if (failed (inlineRegionImpl (interface, src, call->getBlock (),
555
+ if (failed (inlineRegionImpl (interface, cloneCallback, src, call->getBlock (),
556
556
++call->getIterator (), mapper, callResults,
557
557
callableResultTypes, call.getLoc (),
558
558
shouldCloneInlinedRegion, call)))
0 commit comments