Skip to content

Commit 88b7e8e

Browse files
[mlir][SCF] Add an scf.take_assumed_branch transform op.
Given an scf.if conditional, using this transformation is akin to injecting user-specified information that it is always safe to execute only the specified `if` or `else` branch. This is achieved by just replacing the scf.if by the content of one of its branches. This is particularly useful for user-controlled rewriting of conditionals that exist solely to guard against out-of-bounds behavior. At the moment, no assume or assert operation is emitted as it is not always desirable. In the future, this may be controlled by a dedicated attribute. Differential Revision: https://reviews.llvm.org/D148125
1 parent 34f5774 commit 88b7e8e

File tree

4 files changed

+133
-0
lines changed

4 files changed

+133
-0
lines changed

mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ class FuncOp;
1919
} // namespace func
2020
namespace scf {
2121
class ForOp;
22+
class IfOp;
2223
} // namespace scf
2324
} // namespace mlir
2425

mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,4 +215,45 @@ def LoopCoalesceOp : Op<Transform_Dialect, "loop.coalesce", [
215215
}];
216216
}
217217

218+
def TakeAssumedBranchOp : Op<Transform_Dialect, "scf.take_assumed_branch", [
219+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
220+
TransformOpInterface, TransformEachOpTrait]> {
221+
let description = [{
222+
Given an scf.if conditional, inject user-defined information that it is
223+
always safe to execute only the if or else branch.
224+
225+
This is achieved by just replacing the scf.if by the content of one of its
226+
branches.
227+
228+
This is particularly useful for user-controlled rewriting of conditionals
229+
that exist solely to guard against out-of-bounds behavior.
230+
231+
At the moment, no assume or assert operation is emitted as it is not always
232+
desirable. In the future, this may be controlled by a dedicated attribute.
233+
234+
#### Return modes
235+
236+
The transform only consumes its operand and does not produce any result.
237+
The transform definitely fails if `take_else_branch` is specified and the
238+
`else` region is empty.
239+
}];
240+
let arguments = (ins TransformHandleTypeInterface:$target,
241+
OptionalAttr<UnitAttr>:$take_else_branch);
242+
let results = (outs);
243+
244+
let assemblyFormat = [{
245+
$target
246+
(`take_else_branch` $take_else_branch^)?
247+
attr-dict
248+
`:` functional-type(operands, results)
249+
}];
250+
251+
let extraClassDeclaration = [{
252+
::mlir::DiagnosedSilenceableFailure applyToOne(
253+
::mlir::scf::IfOp ifOp,
254+
::mlir::transform::ApplyToEachResultList &results,
255+
::mlir::transform::TransformState &state);
256+
}];
257+
}
258+
218259
#endif // SCF_TRANSFORM_OPS

mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "mlir/Dialect/SCF/Utils/Utils.h"
1717
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
1818
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
19+
#include "mlir/Dialect/Transform/IR/TransformOps.h"
1920
#include "mlir/Dialect/Vector/IR/VectorOps.h"
2021

2122
using namespace mlir;
@@ -245,6 +246,46 @@ transform::LoopCoalesceOp::applyToOne(Operation *op,
245246
return DiagnosedSilenceableFailure::success();
246247
}
247248

249+
//===----------------------------------------------------------------------===//
250+
// TakeAssumedBranchOp
251+
//===----------------------------------------------------------------------===//
252+
/// Replaces the given op with the contents of the given single-block region,
253+
/// using the operands of the block terminator to replace operation results.
254+
static void replaceOpWithRegion(RewriterBase &rewriter, Operation *op,
255+
Region &region) {
256+
assert(llvm::hasSingleElement(region) && "expected single-region block");
257+
Block *block = &region.front();
258+
Operation *terminator = block->getTerminator();
259+
ValueRange results = terminator->getOperands();
260+
rewriter.inlineBlockBefore(block, op, /*blockArgs=*/{});
261+
rewriter.replaceOp(op, results);
262+
rewriter.eraseOp(terminator);
263+
}
264+
265+
DiagnosedSilenceableFailure transform::TakeAssumedBranchOp::applyToOne(
266+
scf::IfOp ifOp, transform::ApplyToEachResultList &results,
267+
transform::TransformState &state) {
268+
TrackingListener listener(state, *this);
269+
IRRewriter rewriter(ifOp->getContext(), &listener);
270+
rewriter.setInsertionPoint(ifOp);
271+
272+
Region &region =
273+
getTakeElseBranch() ? ifOp.getElseRegion() : ifOp.getThenRegion();
274+
if (!llvm::hasSingleElement(region)) {
275+
return emitDefiniteFailure()
276+
<< "requires an scf.if op with a single-block "
277+
<< ((getTakeElseBranch()) ? "`else`" : "`then`") << " region";
278+
}
279+
replaceOpWithRegion(rewriter, ifOp, region);
280+
return DiagnosedSilenceableFailure::success();
281+
}
282+
283+
void transform::TakeAssumedBranchOp::getEffects(
284+
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
285+
consumesHandle(getTarget(), effects);
286+
modifiesPayload(effects);
287+
}
288+
248289
//===----------------------------------------------------------------------===//
249290
// Transform op registration
250291
//===----------------------------------------------------------------------===//
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file -verify-diagnostics --allow-unregistered-dialect | FileCheck %s
2+
3+
func.func @if_no_else(%cond: i1, %a: index, %b: memref<?xf32>, %c: i8) {
4+
scf.if %cond {
5+
"some_op"(%cond, %b) : (i1, memref<?xf32>) -> ()
6+
scf.yield
7+
}
8+
return
9+
}
10+
11+
transform.sequence failures(propagate) {
12+
^bb0(%arg1: !transform.any_op):
13+
%if = transform.structured.match ops{["scf.if"]} in %arg1
14+
: (!transform.any_op) -> !transform.any_op
15+
16+
// expected-error @+1 {{requires an scf.if op with a single-block `else` region}}
17+
transform.scf.take_assumed_branch %if take_else_branch
18+
: (!transform.any_op) -> ()
19+
}
20+
21+
// -----
22+
23+
// CHECK-LABEL: tile_tensor_pad
24+
func.func @tile_tensor_pad(
25+
%arg0 : tensor<?x?xf32>, %cst : f32, %low: index, %high: index)
26+
-> tensor<20x40xf32>
27+
{
28+
// CHECK: scf.forall
29+
// CHECK-NOT: scf.if
30+
// CHECK-NOT: tensor.generate
31+
// CHECK-NOT: else
32+
// CHECK: tensor.pad {{.*}} nofold
33+
%0 = tensor.pad %arg0 nofold low[%low, %low] high[%high, %high] {
34+
^bb0(%arg9: index, %arg10: index):
35+
tensor.yield %cst : f32
36+
} : tensor<?x?xf32> to tensor<20x40xf32>
37+
return %0 : tensor<20x40xf32>
38+
}
39+
40+
transform.sequence failures(propagate) {
41+
^bb0(%arg1: !transform.any_op):
42+
%0 = transform.structured.match ops{["tensor.pad"]} in %arg1
43+
: (!transform.any_op) -> !pdl.operation
44+
transform.structured.tile_to_forall_op %0 tile_sizes[1, 1]
45+
46+
%if = transform.structured.match ops{["scf.if"]} in %arg1
47+
: (!transform.any_op) -> !transform.any_op
48+
transform.scf.take_assumed_branch %if take_else_branch
49+
: (!transform.any_op) -> ()
50+
}

0 commit comments

Comments
 (0)