Skip to content

Commit 6c2f847

Browse files
[mlir][Transforms] Dialect Conversion: Add 1:N support to remapInput (llvm#131454)
This commit adds 1:N support to `SignatureConversion::remapInputs`. This API allows users to replace a block argument with multiple replacement values. (And the block argument is dropped.) The API already supported "bbarg --> multiple bbargs" mappings, but "bbarg --> multiple SSA values" was missing. --------- Co-authored-by: Markus Böck <markus.boeck02@gmail.com>
1 parent 911953a commit 6c2f847

File tree

6 files changed

+75
-33
lines changed

6 files changed

+75
-33
lines changed

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,14 @@ class TypeConverter {
6565
SignatureConversion(unsigned numOrigInputs)
6666
: remappedInputs(numOrigInputs) {}
6767

68-
/// This struct represents a range of new types or a single value that
68+
/// This struct represents a range of new types or a range of values that
6969
/// remaps an existing signature input.
7070
struct InputMapping {
7171
size_t inputNo, size;
72-
Value replacementValue;
72+
SmallVector<Value, 1> replacementValues;
73+
74+
/// Return "true" if this input was replaces with one or multiple values.
75+
bool replacedWithValues() const { return !replacementValues.empty(); }
7376
};
7477

7578
/// Return the argument types for the new signature.
@@ -92,9 +95,9 @@ class TypeConverter {
9295
/// used if the new types are not intended to remap an existing input.
9396
void addInputs(ArrayRef<Type> types);
9497

95-
/// Remap an input of the original signature to another `replacement`
96-
/// value. This drops the original argument.
97-
void remapInput(unsigned origInputNo, Value replacement);
98+
/// Remap an input of the original signature to `replacements`
99+
/// values. This drops the original argument.
100+
void remapInput(unsigned origInputNo, ArrayRef<Value> replacements);
98101

99102
private:
100103
/// Remap an input of the original signature with a range of types in the

mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
274274
// and canonicalize that away later.
275275
Value attribution = gpuFuncOp.getWorkgroupAttributions()[idx];
276276
auto type = cast<MemRefType>(attribution.getType());
277-
auto descr = MemRefDescriptor::fromStaticShape(
277+
Value descr = MemRefDescriptor::fromStaticShape(
278278
rewriter, loc, *getTypeConverter(), type, memory);
279279
signatureConversion.remapInput(numProperArguments + idx, descr);
280280
}
@@ -303,7 +303,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
303303
alignment = alignAttr.getInt();
304304
Value allocated = rewriter.create<LLVM::AllocaOp>(
305305
gpuFuncOp.getLoc(), ptrType, elementType, numElements, alignment);
306-
auto descr = MemRefDescriptor::fromStaticShape(
306+
Value descr = MemRefDescriptor::fromStaticShape(
307307
rewriter, loc, *getTypeConverter(), type, allocated);
308308
signatureConversion.remapInput(
309309
numProperArguments + numWorkgroupAttributions + idx, descr);

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1341,7 +1341,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
13411341
rewriter.getUnknownLoc());
13421342
for (unsigned i = 0; i < origArgCount; ++i) {
13431343
auto inputMap = signatureConversion.getInputMapping(i);
1344-
if (!inputMap || inputMap->replacementValue)
1344+
if (!inputMap || inputMap->replacedWithValues())
13451345
continue;
13461346
Location origLoc = block->getArgument(i).getLoc();
13471347
for (unsigned j = 0; j < inputMap->size; ++j)
@@ -1390,12 +1390,12 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
13901390
continue;
13911391
}
13921392

1393-
if (Value repl = inputMap->replacementValue) {
1394-
// This block argument was dropped and a replacement value was provided.
1393+
if (inputMap->replacedWithValues()) {
1394+
// This block argument was dropped and replacement values were provided.
13951395
assert(inputMap->size == 0 &&
13961396
"invalid to provide a replacement value when the argument isn't "
13971397
"dropped");
1398-
mapping.map(origArg, repl);
1398+
mapping.map(origArg, inputMap->replacementValues);
13991399
appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
14001400
continue;
14011401
}
@@ -2807,14 +2807,15 @@ void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
28072807
assert(!remappedInputs[origInputNo] && "input has already been remapped");
28082808
assert(newInputCount != 0 && "expected valid input count");
28092809
remappedInputs[origInputNo] =
2810-
InputMapping{newInputNo, newInputCount, /*replacementValue=*/nullptr};
2810+
InputMapping{newInputNo, newInputCount, /*replacementValues=*/{}};
28112811
}
28122812

2813-
void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
2814-
Value replacementValue) {
2813+
void TypeConverter::SignatureConversion::remapInput(
2814+
unsigned origInputNo, ArrayRef<Value> replacements) {
28152815
assert(!remappedInputs[origInputNo] && "input has already been remapped");
2816-
remappedInputs[origInputNo] =
2817-
InputMapping{origInputNo, /*size=*/0, replacementValue};
2816+
remappedInputs[origInputNo] = InputMapping{
2817+
origInputNo, /*size=*/0,
2818+
SmallVector<Value, 1>(replacements.begin(), replacements.end())};
28182819
}
28192820

28202821
LogicalResult TypeConverter::convertType(Type t,

mlir/test/Transforms/test-legalizer.mlir

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -472,11 +472,32 @@ func.func @circular_mapping() {
472472

473473
// -----
474474

475-
func.func @test_1_to_n_block_signature_conversion() {
476-
"test.duplicate_block_args"() ({
475+
// CHECK-LABEL: func @test_duplicate_block_arg()
476+
// CHECK: test.convert_block_args is_legal duplicate {
477+
// CHECK: ^{{.*}}(%[[arg0:.*]]: i64, %[[arg1:.*]]: i64):
478+
// CHECK: "test.valid"(%[[arg0]], %[[arg1]])
479+
// CHECK: }
480+
func.func @test_duplicate_block_arg() {
481+
test.convert_block_args duplicate {
477482
^bb0(%arg0: i64):
478483
"test.repetitive_1_to_n_consumer"(%arg0) : (i64) -> ()
479-
}) {} : () -> ()
484+
} : () -> ()
485+
"test.return"() : () -> ()
486+
}
487+
488+
// -----
489+
490+
// CHECK-LABEL: func @test_remap_block_arg()
491+
// CHECK: %[[repl:.*]] = "test.legal_op"() : () -> i32
492+
// CHECK: test.convert_block_args %[[repl]] is_legal replace_with_operand {
493+
// CHECK-NEXT: "test.valid"(%[[repl]], %[[repl]])
494+
// CHECK: }
495+
func.func @test_remap_block_arg() {
496+
%0 = "test.legal_op"() : () -> (i32)
497+
test.convert_block_args %0 replace_with_operand {
498+
^bb0(%arg0: i32):
499+
"test.repetitive_1_to_n_consumer"(%arg0) : (i32) -> ()
500+
} : (i32) -> ()
480501
"test.return"() : () -> ()
481502
}
482503

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1940,9 +1940,17 @@ def LegalOpC : TEST_Op<"legal_op_c">,
19401940
Arguments<(ins I32)>, Results<(outs I32)>;
19411941
def LegalOpD : TEST_Op<"legal_op_d">, Arguments<(ins AnyType)>;
19421942

1943-
def DuplicateBlockArgsOp : TEST_Op<"duplicate_block_args", [SingleBlock]> {
1944-
let arguments = (ins UnitAttr:$is_legal);
1943+
def ConvertBlockArgsOp : TEST_Op<"convert_block_args", [SingleBlock]> {
1944+
let arguments = (ins UnitAttr:$is_legal, UnitAttr:$replace_with_operand,
1945+
UnitAttr:$duplicate, Optional<AnyType>:$val);
19451946
let regions = (region SizedRegion<1>:$body);
1947+
let assemblyFormat = [{
1948+
$val
1949+
(`is_legal` $is_legal^)?
1950+
(`duplicate` $duplicate^)?
1951+
(`replace_with_operand` $replace_with_operand^)?
1952+
$body attr-dict `:` functional-type(operands, results)
1953+
}];
19461954
}
19471955

19481956
// Check that the conversion infrastructure can properly undo the creation of

mlir/test/lib/Dialect/Test/TestPatterns.cpp

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1193,22 +1193,31 @@ class TestEraseOp : public ConversionPattern {
11931193
}
11941194
};
11951195

1196-
/// This pattern matches a test.duplicate_block_args op and duplicates all
1197-
/// block arguments.
1198-
class TestDuplicateBlockArgs
1199-
: public OpConversionPattern<DuplicateBlockArgsOp> {
1200-
using OpConversionPattern<DuplicateBlockArgsOp>::OpConversionPattern;
1196+
/// This pattern matches a test.convert_block_args op. It either:
1197+
/// a) Duplicates all block arguments,
1198+
/// b) or: drops all block arguments and replaces each with 2x the first
1199+
/// operand.
1200+
class TestConvertBlockArgs : public OpConversionPattern<ConvertBlockArgsOp> {
1201+
using OpConversionPattern<ConvertBlockArgsOp>::OpConversionPattern;
12011202

12021203
LogicalResult
1203-
matchAndRewrite(DuplicateBlockArgsOp op, OpAdaptor adaptor,
1204+
matchAndRewrite(ConvertBlockArgsOp op, OpAdaptor adaptor,
12041205
ConversionPatternRewriter &rewriter) const override {
12051206
if (op.getIsLegal())
12061207
return failure();
1207-
rewriter.startOpModification(op);
12081208
Block *body = &op.getBody().front();
12091209
TypeConverter::SignatureConversion result(body->getNumArguments());
1210-
for (auto it : llvm::enumerate(body->getArgumentTypes()))
1211-
result.addInputs(it.index(), {it.value(), it.value()});
1210+
for (auto it : llvm::enumerate(body->getArgumentTypes())) {
1211+
if (op.getReplaceWithOperand()) {
1212+
result.remapInput(it.index(), {adaptor.getVal(), adaptor.getVal()});
1213+
} else if (op.getDuplicate()) {
1214+
result.addInputs(it.index(), {it.value(), it.value()});
1215+
} else {
1216+
// No action specified. Pattern does not apply.
1217+
return failure();
1218+
}
1219+
}
1220+
rewriter.startOpModification(op);
12121221
rewriter.applySignatureConversion(body, result, getTypeConverter());
12131222
op.setIsLegal(true);
12141223
rewriter.finalizeOpModification(op);
@@ -1355,7 +1364,7 @@ struct TestLegalizePatternDriver
13551364
patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
13561365
TestPassthroughInvalidOp, TestMultiple1ToNReplacement>(
13571366
&getContext(), converter);
1358-
patterns.add<TestDuplicateBlockArgs>(converter, &getContext());
1367+
patterns.add<TestConvertBlockArgs>(converter, &getContext());
13591368
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
13601369
converter);
13611370
mlir::populateCallOpTypeConversionPattern(patterns, converter);
@@ -1406,8 +1415,8 @@ struct TestLegalizePatternDriver
14061415
target.addDynamicallyLegalOp<TestOpInPlaceSelfFold>(
14071416
[](TestOpInPlaceSelfFold op) { return op.getFolded(); });
14081417

1409-
target.addDynamicallyLegalOp<DuplicateBlockArgsOp>(
1410-
[](DuplicateBlockArgsOp op) { return op.getIsLegal(); });
1418+
target.addDynamicallyLegalOp<ConvertBlockArgsOp>(
1419+
[](ConvertBlockArgsOp op) { return op.getIsLegal(); });
14111420

14121421
// Handle a partial conversion.
14131422
if (mode == ConversionMode::Partial) {

0 commit comments

Comments
 (0)