Skip to content

Commit 730f4a1

Browse files
authored
[mlir][spirv] Split header and merge block in mlir.selections (llvm#134875)
In the example below with the current code the first selection construct (`if`/`else` in GLSL for simplicity) share its merge block with a header block of the second construct. ``` bool _115; if (_107) { // ... _115 = _200 < _174; } else { _115 = _107; } bool _123; if (_115) { // ... _123 = _213 < _174; } else { _123 = _115; } ``` This results in a malformed nesting of `mlir.selection` instructions where one selection ends up inside a header block of another selection construct. For example: ``` %61 = spirv.mlir.selection -> i1 { %80 = spirv.mlir.selection -> i1 { spirv.BranchConditional %60, ^bb1, ^bb2(%60 : i1) ^bb1: // pred: ^bb0 // ... spirv.Branch ^bb2(%101 : i1) ^bb2(%102: i1): // 2 preds: ^bb0, ^bb1 spirv.mlir.merge %102 : i1 } spirv.BranchConditional %80, ^bb1, ^bb2(%80 : i1) ^bb1: // pred: ^bb0 // ... spirv.Branch ^bb2(%90 : i1) ^bb2(%91: i1): // 2 preds: ^bb0, ^bb1 spirv.mlir.merge %91 : i1 } ``` This change ensures that the merge block of one selection is not a header block of another, splitting blocks if necessary. The existing block splitting mechanism is updated to handle this case.
1 parent 8dc89e3 commit 730f4a1

File tree

3 files changed

+87
-15
lines changed

3 files changed

+87
-15
lines changed

mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2300,23 +2300,22 @@ LogicalResult spirv::Deserializer::splitConditionalBlocks() {
23002300
if (!isa<spirv::BranchConditionalOp>(terminator))
23012301
continue;
23022302

2303-
// Do not split blocks that only contain a conditional branch, i.e., block
2304-
// size is <= 1.
2305-
if (block->begin() != block->end() &&
2306-
std::next(block->begin()) != block->end()) {
2303+
// Check if the current header block is a merge block of another construct.
2304+
bool splitHeaderMergeBlock = false;
2305+
for (const auto &[_, mergeInfo] : blockMergeInfo) {
2306+
if (mergeInfo.mergeBlock == block)
2307+
splitHeaderMergeBlock = true;
2308+
}
2309+
2310+
// Do not split a block that only contains a conditional branch, unless it
2311+
// is also a merge block of another construct - in that case we want to
2312+
// split the block. We do not want two constructs to share header / merge
2313+
// block.
2314+
if (!llvm::hasSingleElement(*block) || splitHeaderMergeBlock) {
23072315
Block *newBlock = block->splitBlock(terminator);
23082316
OpBuilder builder(block, block->end());
23092317
builder.create<spirv::BranchOp>(block->getParent()->getLoc(), newBlock);
23102318

2311-
// If the split block was a merge block of another region we need to
2312-
// update the map.
2313-
for (auto it = blockMergeInfo.begin(); it != blockMergeInfo.end(); ++it) {
2314-
auto &[ignore, mergeInfo] = *it;
2315-
if (mergeInfo.mergeBlock == block) {
2316-
mergeInfo.mergeBlock = newBlock;
2317-
}
2318-
}
2319-
23202319
// After splitting we need to update the map to use the new block as a
23212320
// header.
23222321
blockMergeInfo.erase(block);

mlir/lib/Target/SPIRV/Deserialization/Deserializer.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -246,8 +246,10 @@ class Deserializer {
246246
return opBuilder.getStringAttr(attrName);
247247
}
248248

249-
// Move a conditional branch into a separate basic block to avoid sinking
250-
// defs that are required outside a selection region.
249+
/// Move a conditional branch into a separate basic block to avoid unnecessary
250+
/// sinking of defs that may be required outside a selection region. This
251+
/// function also ensures that a single block cannot be a header block of one
252+
/// selection construct and the merge block of another.
251253
LogicalResult splitConditionalBlocks();
252254

253255
//===--------------------------------------------------------------------===//
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
; RUN: %if spirv-tools %{ spirv-as --target-env spv1.0 %s -o - | mlir-translate --deserialize-spirv - -o - | FileCheck %s %}
2+
3+
; COM: The purpose of this test is to check that in the case where two selections
4+
; COM: regions share a header / merge block, this block is split and the selection
5+
; COM: regions are not incorrectly nested.
6+
7+
; CHECK: spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
8+
; CHECK: spirv.func @main() "None" {
9+
; CHECK: spirv.mlir.selection {
10+
; CHECK-NEXT: spirv.BranchConditional {{.*}}, ^[[bb:.+]], ^[[bb:.+]]
11+
; CHECK-NEXT: ^[[bb:.+]]
12+
; CHECK: spirv.Branch ^[[bb:.+]]
13+
; CHECK-NEXT: ^[[bb:.+]]:
14+
; CHECK-NEXT: spirv.mlir.merge
15+
; CHECK-NEXT: }
16+
; CHECK: spirv.mlir.selection {
17+
; CHECK-NEXT: spirv.BranchConditional {{.*}}, ^[[bb:.+]], ^[[bb:.+]]
18+
; CHECK-NEXT: ^[[bb:.+]]
19+
; CHECK: spirv.Branch ^[[bb:.+]]
20+
; CHECK-NEXT: ^[[bb:.+]]:
21+
; CHECK-NEXT: spirv.mlir.merge
22+
; CHECK-NEXT: }
23+
; CHECK: spirv.Return
24+
; CHECK-NEXT: }
25+
; CHECK: }
26+
27+
OpCapability Shader
28+
%2 = OpExtInstImport "GLSL.std.450"
29+
OpMemoryModel Logical GLSL450
30+
OpEntryPoint Fragment %main "main" %colorOut
31+
OpExecutionMode %main OriginUpperLeft
32+
OpDecorate %colorOut Location 0
33+
%void = OpTypeVoid
34+
%4 = OpTypeFunction %void
35+
%float = OpTypeFloat 32
36+
%v4float = OpTypeVector %float 4
37+
%fun_v4float = OpTypePointer Function %v4float
38+
%float_1 = OpConstant %float 1
39+
%float_0 = OpConstant %float 0
40+
%13 = OpConstantComposite %v4float %float_1 %float_0 %float_0 %float_1
41+
%out_v4float = OpTypePointer Output %v4float
42+
%colorOut = OpVariable %out_v4float Output
43+
%uint = OpTypeInt 32 0
44+
%uint_0 = OpConstant %uint 0
45+
%out_float = OpTypePointer Output %float
46+
%bool = OpTypeBool
47+
%25 = OpConstantComposite %v4float %float_1 %float_1 %float_0 %float_1
48+
%main = OpFunction %void None %4
49+
%6 = OpLabel
50+
%color = OpVariable %fun_v4float Function
51+
OpStore %color %13
52+
%19 = OpAccessChain %out_float %colorOut %uint_0
53+
%20 = OpLoad %float %19
54+
%22 = OpFOrdEqual %bool %20 %float_1
55+
OpSelectionMerge %24 None
56+
OpBranchConditional %22 %23 %24
57+
%23 = OpLabel
58+
OpStore %color %25
59+
OpBranch %24
60+
%24 = OpLabel
61+
%30 = OpFOrdEqual %bool %20 %float_1
62+
OpSelectionMerge %32 None
63+
OpBranchConditional %30 %31 %32
64+
%31 = OpLabel
65+
OpStore %color %25
66+
OpBranch %32
67+
%32 = OpLabel
68+
%26 = OpLoad %v4float %color
69+
OpStore %colorOut %26
70+
OpReturn
71+
OpFunctionEnd

0 commit comments

Comments
 (0)