Skip to content

Commit 660fded

Browse files
committed
[mlir][bufferization] Add specialized lowering for deallocs with one memref but arbitrary retains
It is often the case that many values in the `memrefs` operand list can be split off to separate dealloc operations by the `--buffer-deallocation-simplification` pass; however, the retain list has to be preserved initially. Further canonicalization can often trim it down considerably, but some retains may remain. In those cases, the general lowering would be chosen, but it is very inefficient. This commit adds another lowering for those cases which avoids the allocation of auxiliary memrefs and the helper function while still producing code that is linear in the number of operands of the dealloc operation. Reviewed By: springerm Differential Revision: https://reviews.llvm.org/D157692
1 parent 6238b8e commit 660fded

File tree

2 files changed

+136
-5
lines changed

2 files changed

+136
-5
lines changed

mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp

Lines changed: 96 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,91 @@ class DeallocOpConversion
120120
return success();
121121
}
122122

123+
/// A special case lowering for the deallocation operation with exactly one
124+
/// memref, but arbitrary number of retained values. This avoids the helper
125+
/// function that the general case needs and thus also avoids storing indices
126+
/// to specifically allocated memrefs. The size of the code produced by this
127+
/// lowering is linear in the number of retained values.
128+
///
129+
/// Example:
130+
/// ```mlir
131+
/// %0:2 = bufferization.dealloc (%m : memref<2xf32>) if (%cond)
132+
///                              retain (%r0, %r1 : memref<1xf32>, memref<2xf32>)
133+
/// return %0#0, %0#1 : i1, i1
134+
/// ```
/// is lowered to:
135+
/// ```mlir
136+
/// %m_base_pointer = memref.extract_aligned_pointer_as_index %m
137+
/// %r0_base_pointer = memref.extract_aligned_pointer_as_index %r0
138+
/// %r0_does_not_alias = arith.cmpi ne, %m_base_pointer, %r0_base_pointer
139+
/// %r1_base_pointer = memref.extract_aligned_pointer_as_index %r1
140+
/// %r1_does_not_alias = arith.cmpi ne, %m_base_pointer, %r1_base_pointer
141+
/// %not_retained = arith.andi %r0_does_not_alias, %r1_does_not_alias : i1
142+
/// %should_dealloc = arith.andi %not_retained, %cond : i1
143+
/// scf.if %should_dealloc {
144+
/// memref.dealloc %m : memref<2xf32>
145+
/// }
146+
/// %true = arith.constant true
147+
/// %r0_does_alias = arith.xori %r0_does_not_alias, %true : i1
148+
/// %r0_ownership = arith.andi %r0_does_alias, %cond : i1
149+
/// %r1_does_alias = arith.xori %r1_does_not_alias, %true : i1
150+
/// %r1_ownership = arith.andi %r1_does_alias, %cond : i1
151+
/// return %r0_ownership, %r1_ownership : i1, i1
152+
/// ```
153+
LogicalResult rewriteOneMemrefMultipleRetainCase(
154+
bufferization::DeallocOp op, OpAdaptor adaptor,
155+
ConversionPatternRewriter &rewriter) const {
156+
assert(adaptor.getMemrefs().size() == 1 && "expected only one memref");
157+
158+
// Compute the base pointer indices, compare all retained indices to the
159+
// memref index to check if they alias.
160+
SmallVector<Value> doesNotAliasList;
161+
Value memrefAsIdx = rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(
162+
op->getLoc(), adaptor.getMemrefs()[0]);
163+
for (Value retained : adaptor.getRetained()) {
164+
Value retainedAsIdx =
165+
rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(op->getLoc(),
166+
retained);
167+
Value doesNotAlias = rewriter.create<arith::CmpIOp>(
168+
op->getLoc(), arith::CmpIPredicate::ne, memrefAsIdx, retainedAsIdx);
169+
doesNotAliasList.push_back(doesNotAlias);
170+
}
171+
172+
// AND-reduce the list of booleans from above.
// NOTE(review): `front()` requires at least one retained value. This is not
// asserted here; the visible dispatcher only reaches this lowering after the
// one-memref/no-retain case has been handled -- confirm for any other caller.
173+
Value prev = doesNotAliasList.front();
174+
for (Value doesNotAlias : ArrayRef(doesNotAliasList).drop_front())
175+
prev = rewriter.create<arith::AndIOp>(op->getLoc(), prev, doesNotAlias);
176+
177+
// Also consider the condition given by the dealloc operation and perform a
178+
// conditional deallocation guarded by that value.
179+
Value shouldDealloc = rewriter.create<arith::AndIOp>(
180+
op->getLoc(), prev, adaptor.getConditions()[0]);
181+
182+
rewriter.create<scf::IfOp>(
183+
op.getLoc(), shouldDealloc, [&](OpBuilder &builder, Location loc) {
184+
builder.create<memref::DeallocOp>(loc, adaptor.getMemrefs()[0]);
185+
builder.create<scf::YieldOp>(loc);
186+
});
187+
188+
// Compute the replacement values for the dealloc operation results. This
189+
// inserts an already canonicalized form of
190+
// `select(does_alias_with_memref(r), memref_cond, false)` for each retained
191+
// value r, i.e. `and(xor(does_not_alias(r), true), memref_cond)`.
192+
SmallVector<Value> replacements;
193+
Value trueVal = rewriter.create<arith::ConstantOp>(
194+
op->getLoc(), rewriter.getBoolAttr(true));
195+
for (Value doesNotAlias : doesNotAliasList) {
196+
Value aliases =
197+
rewriter.create<arith::XOrIOp>(op->getLoc(), doesNotAlias, trueVal);
198+
Value result = rewriter.create<arith::AndIOp>(op->getLoc(), aliases,
199+
adaptor.getConditions()[0]);
200+
replacements.push_back(result);
201+
}
202+
203+
rewriter.replaceOp(op, replacements);
204+
205+
return success();
206+
}
207+
123208
/// Lowering that supports all features the dealloc operation has to offer. It
124209
/// computes the base pointer of each memref (as an index), stores it in a
125210
/// new memref helper structure and passes it to the helper function generated
@@ -310,12 +395,20 @@ class DeallocOpConversion
310395
matchAndRewrite(bufferization::DeallocOp op, OpAdaptor adaptor,
311396
ConversionPatternRewriter &rewriter) const override {
312397
// Lower the trivial case.
313-
if (adaptor.getMemrefs().empty())
314-
return rewriter.eraseOp(op), success();
398+
if (adaptor.getMemrefs().empty()) {
399+
Value falseVal = rewriter.create<arith::ConstantOp>(
400+
op.getLoc(), rewriter.getBoolAttr(false));
401+
rewriter.replaceOp(
402+
op, SmallVector<Value>(adaptor.getRetained().size(), falseVal));
403+
return success();
404+
}
315405

316406
if (adaptor.getMemrefs().size() == 1 && adaptor.getRetained().empty())
317407
return rewriteOneMemrefNoRetainCase(op, adaptor, rewriter);
318408

409+
if (adaptor.getMemrefs().size() == 1)
410+
return rewriteOneMemrefMultipleRetainCase(op, adaptor, rewriter);
411+
319412
return rewriteGeneralCase(op, adaptor, rewriter);
320413
}
321414

@@ -535,8 +628,7 @@ struct BufferizationToMemRefPass
535628
// Build dealloc helper function if there are deallocs.
536629
func::FuncOp helperFuncOp;
537630
getOperation()->walk([&](bufferization::DeallocOp deallocOp) {
538-
if (deallocOp.getMemrefs().size() > 1 ||
539-
!deallocOp.getRetained().empty()) {
631+
if (deallocOp.getMemrefs().size() > 1) {
540632
helperFuncOp = DeallocOpConversion::buildDeallocationHelperFunction(
541633
builder, getOperation()->getLoc(), symbolTable);
542634
return WalkResult::interrupt();

mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,17 +66,29 @@ func.func @conversion_with_invalid_layout_map(%arg0 : memref<?xf32, strided<[10]
6666
memref.dealloc %arg0 : memref<?xf32, strided<[10], offset: ?>>
6767
return %1 : memref<?xf32, strided<[10], offset: ?>>
6868
}
69+
6970
// -----
7071

7172
// CHECK-LABEL: func @conversion_dealloc_empty
7273
func.func @conversion_dealloc_empty() {
73-
// CHECK-NEXT: return
74+
// CHECK-NOT: bufferization.dealloc
7475
bufferization.dealloc
7576
return
7677
}
7778

7879
// -----
7980

81+
func.func @conversion_dealloc_empty_but_retains(%arg0: memref<2xi32>, %arg1: memref<2xi32>) -> (i1, i1) {
82+
%0:2 = bufferization.dealloc retain (%arg0, %arg1 : memref<2xi32>, memref<2xi32>)
83+
return %0#0, %0#1 : i1, i1
84+
}
85+
86+
// CHECK-LABEL: func @conversion_dealloc_empty
87+
// CHECK-NEXT: [[FALSE:%.+]] = arith.constant false
88+
// CHECK-NEXT: return [[FALSE]], [[FALSE]] :
89+
90+
// -----
91+
8092
// CHECK-NOT: func @deallocHelper
8193
// CHECK-LABEL: func @conversion_dealloc_simple
8294
// CHECK-SAME: [[ARG0:%.+]]: memref<2xf32>
@@ -93,6 +105,33 @@ func.func @conversion_dealloc_simple(%arg0: memref<2xf32>, %arg1: i1) {
93105

94106
// -----
95107

108+
func.func @conversion_dealloc_one_memref_and_multiple_retained(%arg0: memref<2xf32>, %arg1: memref<1xf32>, %arg2: i1, %arg3: memref<2xf32>) -> (i1, i1) {
109+
%0:2 = bufferization.dealloc (%arg0 : memref<2xf32>) if (%arg2) retain (%arg1, %arg3 : memref<1xf32>, memref<2xf32>)
110+
return %0#0, %0#1 : i1, i1
111+
}
112+
113+
// CHECK-LABEL: func @conversion_dealloc_one_memref_and_multiple_retained
114+
// CHECK-SAME: ([[ARG0:%.+]]: memref<2xf32>, [[ARG1:%.+]]: memref<1xf32>, [[ARG2:%.+]]: i1, [[ARG3:%.+]]: memref<2xf32>)
115+
// CHECK-DAG: [[M0:%.+]] = memref.extract_aligned_pointer_as_index [[ARG0]]
116+
// CHECK-DAG: [[R0:%.+]] = memref.extract_aligned_pointer_as_index [[ARG1]]
117+
// CHECK-DAG: [[R1:%.+]] = memref.extract_aligned_pointer_as_index [[ARG3]]
118+
// CHECK-DAG: [[DOES_NOT_ALIAS_R0:%.+]] = arith.cmpi ne, [[M0]], [[R0]] : index
119+
// CHECK-DAG: [[DOES_NOT_ALIAS_R1:%.+]] = arith.cmpi ne, [[M0]], [[R1]] : index
120+
// CHECK: [[NOT_RETAINED:%.+]] = arith.andi [[DOES_NOT_ALIAS_R0]], [[DOES_NOT_ALIAS_R1]]
121+
// CHECK: [[SHOULD_DEALLOC:%.+]] = arith.andi [[NOT_RETAINED]], [[ARG2]]
122+
// CHECK: scf.if [[SHOULD_DEALLOC]]
123+
// CHECK: memref.dealloc [[ARG0]]
124+
// CHECK: }
125+
// CHECK-DAG: [[ALIASES_R0:%.+]] = arith.xori [[DOES_NOT_ALIAS_R0]], %true
126+
// CHECK-DAG: [[ALIASES_R1:%.+]] = arith.xori [[DOES_NOT_ALIAS_R1]], %true
127+
// CHECK-DAG: [[RES0:%.+]] = arith.andi [[ALIASES_R0]], [[ARG2]]
128+
// CHECK-DAG: [[RES1:%.+]] = arith.andi [[ALIASES_R1]], [[ARG2]]
129+
// CHECK: return [[RES0]], [[RES1]]
130+
131+
// CHECK-NOT: func @dealloc_helper
132+
133+
// -----
134+
96135
func.func @conversion_dealloc_multiple_memrefs_and_retained(%arg0: memref<2xf32>, %arg1: memref<5xf32>, %arg2: memref<1xf32>, %arg3: i1, %arg4: i1, %arg5: memref<2xf32>) -> (i1, i1) {
97136
%0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<5xf32>) if (%arg3, %arg4) retain (%arg2, %arg5 : memref<1xf32>, memref<2xf32>)
98137
return %0#0, %0#1 : i1, i1

0 commit comments

Comments
 (0)