Skip to content

Commit 721c5cc

Browse files
authored
[mlir][spirv] Allow yielding values from loop regions (llvm#135344)
This change extends `spirv.mlir.loop` so it can yield values, the same as `spirv.mlir.selection`.
1 parent 4164741 commit 721c5cc

File tree

7 files changed

+180
-35
lines changed

7 files changed

+180
-35
lines changed

mlir/docs/Dialects/SPIR-V.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -734,6 +734,18 @@ func.func @loop(%count : i32) -> () {
734734
}
735735
```
736736

737+
Similarly to selection, loops can also yield values using `spirv.mlir.merge`. This
738+
mechanism allows values defined within the loop region to be used outside of it.
739+
740+
For example
741+
742+
```mlir
743+
%yielded = spirv.mlir.loop -> i32 {
744+
// ...
745+
spirv.mlir.merge %to_yield : i32
746+
}
747+
```
748+
737749
### Block argument for Phi
738750

739751
There are no direct Phi operations in the SPIR-V dialect; SPIR-V `OpPhi`

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -311,17 +311,27 @@ def SPIRV_LoopOp : SPIRV_Op<"mlir.loop", [InFunctionScope]> {
311311
The continue block should be the second to last block and it should have a
312312
branch to the loop header block. The loop continue block should be the only
313313
block, except the entry block, branching to the header block.
314+
315+
Values defined inside the loop regions cannot be directly used
316+
outside of them; however, the loop region can yield values. These values are
317+
yielded using a `spirv.mlir.merge` op and returned as a result of the loop op.
314318
}];
315319

316320
let arguments = (ins
317321
SPIRV_LoopControlAttr:$loop_control
318322
);
319323

320-
let results = (outs);
324+
let results = (outs Variadic<AnyType>:$results);
321325

322326
let regions = (region AnyRegion:$body);
323327

324-
let builders = [OpBuilder<(ins)>];
328+
let builders = [
329+
OpBuilder<(ins)>,
330+
OpBuilder<(ins "spirv::LoopControl":$loopControl),
331+
[{
332+
build($_builder, $_state, TypeRange(), loopControl);
333+
}]>
334+
];
325335

326336
let extraClassDeclaration = [{
327337
// Returns the entry block.

mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,13 +229,22 @@ ParseResult LoopOp::parse(OpAsmParser &parser, OperationState &result) {
229229
if (parseControlAttribute<spirv::LoopControlAttr, spirv::LoopControl>(parser,
230230
result))
231231
return failure();
232+
233+
if (succeeded(parser.parseOptionalArrow()))
234+
if (parser.parseTypeList(result.types))
235+
return failure();
236+
232237
return parser.parseRegion(*result.addRegion(), /*arguments=*/{});
233238
}
234239

235240
void LoopOp::print(OpAsmPrinter &printer) {
236241
auto control = getLoopControl();
237242
if (control != spirv::LoopControl::None)
238243
printer << " control(" << spirv::stringifyLoopControl(control) << ")";
244+
if (getNumResults() > 0) {
245+
printer << " -> ";
246+
printer << getResultTypes();
247+
}
239248
printer << ' ';
240249
printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
241250
/*printBlockTerminators=*/true);

mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp

Lines changed: 52 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -2003,7 +2003,14 @@ LogicalResult ControlFlowStructurizer::structurize() {
20032003
// block inside the selection (`body.back()`). Values produced by block
20042004
// arguments will be yielded by the selection region. We do not update uses or
20052005
// erase original block arguments yet. It will be done later in the code.
2006-
if (!isLoop) {
2006+
//
2007+
// Code below is not executed for loops as it would interfere with the logic
2008+
// above. Currently block arguments in the merge block are not supported, but
2009+
// instead, the code above copies those arguments from the header block into
2010+
// the merge block. As such, running the code would yield those copied
2011+
// arguments that is most likely not a desired behaviour. This may need to be
2012+
// revisited in the future.
2013+
if (!isLoop)
20072014
for (BlockArgument blockArg : mergeBlock->getArguments()) {
20082015
// Create new block arguments in the last block ("merge block") of the
20092016
// selection region. We create one argument for each argument in
@@ -2013,7 +2020,6 @@ LogicalResult ControlFlowStructurizer::structurize() {
20132020
valuesToYield.push_back(body.back().getArguments().back());
20142021
outsideUses.push_back(blockArg);
20152022
}
2016-
}
20172023

20182024
// All the blocks cloned into the SelectionOp/LoopOp's region can now be
20192025
// cleaned up.
@@ -2025,32 +2031,30 @@ LogicalResult ControlFlowStructurizer::structurize() {
20252031

20262032
// All internal uses should be removed from original blocks by now, so
20272033
// whatever is left is an outside use and will need to be yielded from
2028-
// the newly created selection region.
2029-
if (!isLoop) {
2030-
for (Block *block : constructBlocks) {
2031-
for (Operation &op : *block) {
2032-
if (!op.use_empty())
2033-
for (Value result : op.getResults()) {
2034-
valuesToYield.push_back(mapper.lookupOrNull(result));
2035-
outsideUses.push_back(result);
2036-
}
2037-
}
2038-
for (BlockArgument &arg : block->getArguments()) {
2039-
if (!arg.use_empty()) {
2040-
valuesToYield.push_back(mapper.lookupOrNull(arg));
2041-
outsideUses.push_back(arg);
2034+
// the newly created selection / loop region.
2035+
for (Block *block : constructBlocks) {
2036+
for (Operation &op : *block) {
2037+
if (!op.use_empty())
2038+
for (Value result : op.getResults()) {
2039+
valuesToYield.push_back(mapper.lookupOrNull(result));
2040+
outsideUses.push_back(result);
20422041
}
2042+
}
2043+
for (BlockArgument &arg : block->getArguments()) {
2044+
if (!arg.use_empty()) {
2045+
valuesToYield.push_back(mapper.lookupOrNull(arg));
2046+
outsideUses.push_back(arg);
20432047
}
20442048
}
20452049
}
20462050

20472051
assert(valuesToYield.size() == outsideUses.size());
20482052

2049-
// If we need to yield any values from the selection region we will take
2050-
// care of it here.
2051-
if (!isLoop && !valuesToYield.empty()) {
2053+
// If we need to yield any values from the selection / loop region we will
2054+
// take care of it here.
2055+
if (!valuesToYield.empty()) {
20522056
LLVM_DEBUG(logger.startLine()
2053-
<< "[cf] yielding values from the selection region\n");
2057+
<< "[cf] yielding values from the selection / loop region\n");
20542058

20552059
// Update `mlir.merge` with values to be yield.
20562060
auto mergeOps = body.back().getOps<spirv::MergeOp>();
@@ -2059,25 +2063,40 @@ LogicalResult ControlFlowStructurizer::structurize() {
20592063
merge->setOperands(valuesToYield);
20602064

20612065
// MLIR does not allow changing the number of results of an operation, so
2062-
// we create a new SelectionOp with required list of results and move
2063-
// the region from the initial SelectionOp. The initial operation is then
2064-
// removed. Since we move the region to the new op all links between blocks
2065-
// and remapping we have previously done should be preserved.
2066+
// we create a new SelectionOp / LoopOp with required list of results and
2067+
// move the region from the initial SelectionOp / LoopOp. The initial
2068+
// operation is then removed. Since we move the region to the new op all
2069+
// links between blocks and remapping we have previously done should be
2070+
// preserved.
20662071
builder.setInsertionPoint(&mergeBlock->front());
2067-
auto selectionOp = builder.create<spirv::SelectionOp>(
2068-
location, TypeRange(ValueRange(outsideUses)),
2069-
static_cast<spirv::SelectionControl>(control));
2070-
selectionOp->getRegion(0).takeBody(body);
2072+
2073+
Operation *newOp = nullptr;
2074+
2075+
if (isLoop)
2076+
newOp = builder.create<spirv::LoopOp>(
2077+
location, TypeRange(ValueRange(outsideUses)),
2078+
static_cast<spirv::LoopControl>(control));
2079+
else
2080+
newOp = builder.create<spirv::SelectionOp>(
2081+
location, TypeRange(ValueRange(outsideUses)),
2082+
static_cast<spirv::SelectionControl>(control));
2083+
2084+
newOp->getRegion(0).takeBody(body);
20712085

20722086
// Remove initial op and swap the pointer to the newly created one.
20732087
op->erase();
2074-
op = selectionOp;
2088+
op = newOp;
20752089

2076-
// Update all outside uses to use results of the SelectionOp and remove
2077-
// block arguments from the original merge block.
2090+
// Update all outside uses to use results of the SelectionOp / LoopOp and
2091+
// remove block arguments from the original merge block.
20782092
for (unsigned i = 0, e = outsideUses.size(); i != e; ++i)
2079-
outsideUses[i].replaceAllUsesWith(selectionOp.getResult(i));
2080-
mergeBlock->eraseArguments(0, mergeBlock->getNumArguments());
2093+
outsideUses[i].replaceAllUsesWith(op->getResult(i));
2094+
2095+
// We do not support block arguments in loop merge block. Also running this
2096+
// function with loop would break some of the loop specific code above
2097+
// dealing with block arguments.
2098+
if (!isLoop)
2099+
mergeBlock->eraseArguments(0, mergeBlock->getNumArguments());
20812100
}
20822101

20832102
// Check that whether some op in the to-be-erased blocks still has uses. Those

mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -520,6 +520,13 @@ LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
520520
auto mergeID = getBlockID(mergeBlock);
521521
auto loc = loopOp.getLoc();
522522

523+
// Before we do anything replace results of the selection operation with
524+
// values yielded (with `mlir.merge`) from inside the region.
525+
auto mergeOp = cast<spirv::MergeOp>(mergeBlock->back());
526+
assert(loopOp.getNumResults() == mergeOp.getNumOperands());
527+
for (unsigned i = 0, e = loopOp.getNumResults(); i != e; ++i)
528+
loopOp.getResult(i).replaceAllUsesWith(mergeOp.getOperand(i));
529+
523530
// This LoopOp is in some MLIR block with preceding and following ops. In the
524531
// binary format, it should reside in separate SPIR-V blocks from its
525532
// preceding and following ops. So we need to emit unconditional branches to

mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,47 @@ func.func @only_entry_and_continue_branch_to_header() -> () {
426426

427427
// -----
428428

429+
func.func @loop_yield(%count : i32) -> () {
430+
%zero = spirv.Constant 0: i32
431+
%one = spirv.Constant 1: i32
432+
%var = spirv.Variable init(%zero) : !spirv.ptr<i32, Function>
433+
434+
// CHECK: {{%.*}} = spirv.mlir.loop -> i32 {
435+
%final_i = spirv.mlir.loop -> i32 {
436+
// CHECK-NEXT: spirv.Branch ^bb1({{%.*}}: i32)
437+
spirv.Branch ^header(%zero: i32)
438+
439+
// CHECK-NEXT: ^bb1({{%.*}}: i32):
440+
^header(%i : i32):
441+
%cmp = spirv.SLessThan %i, %count : i32
442+
// CHECK: spirv.BranchConditional %{{.*}}, ^bb2, ^bb4
443+
spirv.BranchConditional %cmp, ^body, ^merge
444+
445+
// CHECK-NEXT: ^bb2:
446+
^body:
447+
// CHECK-NEXT: spirv.Branch ^bb3
448+
spirv.Branch ^continue
449+
450+
// CHECK-NEXT: ^bb3:
451+
^continue:
452+
%new_i = spirv.IAdd %i, %one : i32
453+
// CHECK: spirv.Branch ^bb1({{%.*}}: i32)
454+
spirv.Branch ^header(%new_i: i32)
455+
456+
// CHECK-NEXT: ^bb4:
457+
^merge:
458+
// CHECK-NEXT: spirv.mlir.merge {{%.*}} : i32
459+
spirv.mlir.merge %i : i32
460+
}
461+
462+
// CHECK: spirv.Store "Function" {{%.*}}, {{%.*}} : i32
463+
spirv.Store "Function" %var, %final_i : i32
464+
465+
return
466+
}
467+
468+
// -----
469+
429470
//===----------------------------------------------------------------------===//
430471
// spirv.mlir.merge
431472
//===----------------------------------------------------------------------===//

mlir/test/Target/SPIRV/loop.mlir

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,3 +288,50 @@ spirv.module Physical64 OpenCL requires #spirv.vce<v1.0, [Kernel, Linkage, Addre
288288
spirv.Return
289289
}
290290
}
291+
292+
// -----
293+
294+
// Loop yielding values
295+
296+
spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
297+
spirv.func @loop_yield(%count : i32) -> () "None" {
298+
%zero = spirv.Constant 0: i32
299+
%one = spirv.Constant 1: i32
300+
%var = spirv.Variable init(%zero) : !spirv.ptr<i32, Function>
301+
302+
// CHECK: {{%.*}} = spirv.mlir.loop -> i32 {
303+
%final_i = spirv.mlir.loop -> i32 {
304+
// CHECK-NEXT: spirv.Branch ^[[HEADER:.+]]({{%.*}}: i32)
305+
spirv.Branch ^header(%zero: i32)
306+
307+
// CHECK-NEXT: ^[[HEADER:.+]]({{%.*}}: i32):
308+
^header(%i : i32):
309+
%cmp = spirv.SLessThan %i, %count : i32
310+
// CHECK: spirv.BranchConditional %{{.*}}, ^[[BODY:.+]], ^[[MERGE:.+]]
311+
spirv.BranchConditional %cmp, ^body, ^merge
312+
313+
// CHECK-NEXT: ^[[BODY:.+]]:
314+
^body:
315+
// CHECK-NEXT: spirv.Branch ^[[CONTINUE:.+]]
316+
spirv.Branch ^continue
317+
318+
// CHECK-NEXT: ^[[CONTINUE:.+]]:
319+
^continue:
320+
%new_i = spirv.IAdd %i, %one : i32
321+
// CHECK: spirv.Branch ^[[HEADER:.+]]({{%.*}}: i32)
322+
spirv.Branch ^header(%new_i: i32)
323+
324+
// CHECK-NEXT: ^[[MERGE:.+]]:
325+
^merge:
326+
// CHECK-NEXT: spirv.mlir.merge {{%.*}} : i32
327+
spirv.mlir.merge %i : i32
328+
// CHECK-NEXT: }
329+
}
330+
331+
// CHECK-NEXT: spirv.Store "Function" {{%.*}}, {{%.*}} : i32
332+
spirv.Store "Function" %var, %final_i : i32
333+
334+
// CHECK-NEXT: spirv.Return
335+
spirv.Return
336+
}
337+
}

0 commit comments

Comments
 (0)