@@ -120,6 +120,91 @@ class DeallocOpConversion
120
120
return success ();
121
121
}
122
122
123
+ // / A special case lowering for the deallocation operation with exactly one
124
+ // / memref, but arbitrary number of retained values. This avoids the helper
125
+ // / function that the general case needs and thus also avoids storing indices
126
+ // / to specifically allocated memrefs. The size of the code produced by this
127
+ // / lowering is linear to the number of retained values.
128
+ // /
129
+ // / Example:
130
+ // / ```mlir
131
+ // / %0:2 = bufferization.dealloc (%m : memref<2xf32>) if (%cond)
132
+ // retain (%r0, %r1 : memref<1xf32>, memref<2xf32>)
133
+ // / return %0#0, %0#1 : i1, i1
134
+ // / ```
135
+ // / ```mlir
136
+ // / %m_base_pointer = memref.extract_aligned_pointer_as_index %m
137
+ // / %r0_base_pointer = memref.extract_aligned_pointer_as_index %r0
138
+ // / %r0_does_not_alias = arith.cmpi ne, %m_base_pointer, %r0_base_pointer
139
+ // / %r1_base_pointer = memref.extract_aligned_pointer_as_index %r1
140
+ // / %r1_does_not_alias = arith.cmpi ne, %m_base_pointer, %r1_base_pointer
141
+ // / %not_retained = arith.andi %r0_does_not_alias, %r1_does_not_alias : i1
142
+ // / %should_dealloc = arith.andi %not_retained, %cond : i1
143
+ // / scf.if %should_dealloc {
144
+ // / memref.dealloc %m : memref<2xf32>
145
+ // / }
146
+ // / %true = arith.constant true
147
+ // / %r0_does_alias = arith.xori %r0_does_not_alias, %true : i1
148
+ // / %r0_ownership = arith.andi %r0_does_alias, %cond : i1
149
+ // / %r1_does_alias = arith.xori %r1_does_not_alias, %true : i1
150
+ // / %r1_ownership = arith.andi %r1_does_alias, %cond : i1
151
+ // / return %r0_ownership, %r1_ownership : i1, i1
152
+ // / ```
153
+ LogicalResult rewriteOneMemrefMultipleRetainCase (
154
+ bufferization::DeallocOp op, OpAdaptor adaptor,
155
+ ConversionPatternRewriter &rewriter) const {
156
+ assert (adaptor.getMemrefs ().size () == 1 && " expected only one memref" );
157
+
158
+ // Compute the base pointer indices, compare all retained indices to the
159
+ // memref index to check if they alias.
160
+ SmallVector<Value> doesNotAliasList;
161
+ Value memrefAsIdx = rewriter.create <memref::ExtractAlignedPointerAsIndexOp>(
162
+ op->getLoc (), adaptor.getMemrefs ()[0 ]);
163
+ for (Value retained : adaptor.getRetained ()) {
164
+ Value retainedAsIdx =
165
+ rewriter.create <memref::ExtractAlignedPointerAsIndexOp>(op->getLoc (),
166
+ retained);
167
+ Value doesNotAlias = rewriter.create <arith::CmpIOp>(
168
+ op->getLoc (), arith::CmpIPredicate::ne, memrefAsIdx, retainedAsIdx);
169
+ doesNotAliasList.push_back (doesNotAlias);
170
+ }
171
+
172
+ // AND-reduce the list of booleans from above.
173
+ Value prev = doesNotAliasList.front ();
174
+ for (Value doesNotAlias : ArrayRef (doesNotAliasList).drop_front ())
175
+ prev = rewriter.create <arith::AndIOp>(op->getLoc (), prev, doesNotAlias);
176
+
177
+ // Also consider the condition given by the dealloc operation and perform a
178
+ // conditional deallocation guarded by that value.
179
+ Value shouldDealloc = rewriter.create <arith::AndIOp>(
180
+ op->getLoc (), prev, adaptor.getConditions ()[0 ]);
181
+
182
+ rewriter.create <scf::IfOp>(
183
+ op.getLoc (), shouldDealloc, [&](OpBuilder &builder, Location loc) {
184
+ builder.create <memref::DeallocOp>(loc, adaptor.getMemrefs ()[0 ]);
185
+ builder.create <scf::YieldOp>(loc);
186
+ });
187
+
188
+ // Compute the replacement values for the dealloc operation results. This
189
+ // inserts an already canonicalized form of
190
+ // `select(does_alias_with_memref(r), memref_cond, false)` for each retained
191
+ // value r.
192
+ SmallVector<Value> replacements;
193
+ Value trueVal = rewriter.create <arith::ConstantOp>(
194
+ op->getLoc (), rewriter.getBoolAttr (true ));
195
+ for (Value doesNotAlias : doesNotAliasList) {
196
+ Value aliases =
197
+ rewriter.create <arith::XOrIOp>(op->getLoc (), doesNotAlias, trueVal);
198
+ Value result = rewriter.create <arith::AndIOp>(op->getLoc (), aliases,
199
+ adaptor.getConditions ()[0 ]);
200
+ replacements.push_back (result);
201
+ }
202
+
203
+ rewriter.replaceOp (op, replacements);
204
+
205
+ return success ();
206
+ }
207
+
123
208
// / Lowering that supports all features the dealloc operation has to offer. It
124
209
// / computes the base pointer of each memref (as an index), stores it in a
125
210
// / new memref helper structure and passes it to the helper function generated
@@ -310,12 +395,20 @@ class DeallocOpConversion
310
395
matchAndRewrite (bufferization::DeallocOp op, OpAdaptor adaptor,
311
396
ConversionPatternRewriter &rewriter) const override {
312
397
// Lower the trivial case.
313
- if (adaptor.getMemrefs ().empty ())
314
- return rewriter.eraseOp (op), success ();
398
+ if (adaptor.getMemrefs ().empty ()) {
399
+ Value falseVal = rewriter.create <arith::ConstantOp>(
400
+ op.getLoc (), rewriter.getBoolAttr (false ));
401
+ rewriter.replaceOp (
402
+ op, SmallVector<Value>(adaptor.getRetained ().size (), falseVal));
403
+ return success ();
404
+ }
315
405
316
406
if (adaptor.getMemrefs ().size () == 1 && adaptor.getRetained ().empty ())
317
407
return rewriteOneMemrefNoRetainCase (op, adaptor, rewriter);
318
408
409
+ if (adaptor.getMemrefs ().size () == 1 )
410
+ return rewriteOneMemrefMultipleRetainCase (op, adaptor, rewriter);
411
+
319
412
return rewriteGeneralCase (op, adaptor, rewriter);
320
413
}
321
414
@@ -535,8 +628,7 @@ struct BufferizationToMemRefPass
535
628
// Build dealloc helper function if there are deallocs.
536
629
func::FuncOp helperFuncOp;
537
630
getOperation ()->walk ([&](bufferization::DeallocOp deallocOp) {
538
- if (deallocOp.getMemrefs ().size () > 1 ||
539
- !deallocOp.getRetained ().empty ()) {
631
+ if (deallocOp.getMemrefs ().size () > 1 ) {
540
632
helperFuncOp = DeallocOpConversion::buildDeallocationHelperFunction (
541
633
builder, getOperation ()->getLoc (), symbolTable);
542
634
return WalkResult::interrupt ();
0 commit comments