Skip to content

Commit e614e84

Browse files
[mlir][memref] Add runtime verification for memref.dim (llvm#130410)
Add runtime verification for `memref.dim`: check that the index is in bounds. Also simplify the pass pipeline for all memref runtime verification checks.
1 parent 489d1e7 commit e614e84

File tree

7 files changed

+74
-33
lines changed

7 files changed

+74
-33
lines changed

mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,18 @@ using namespace mlir;
2323
namespace mlir {
2424
namespace memref {
2525
namespace {
26+
/// Generate a runtime check for lb <= value < ub.
27+
Value generateInBoundsCheck(OpBuilder &builder, Location loc, Value value,
28+
Value lb, Value ub) {
29+
Value inBounds1 = builder.createOrFold<arith::CmpIOp>(
30+
loc, arith::CmpIPredicate::sge, value, lb);
31+
Value inBounds2 = builder.createOrFold<arith::CmpIOp>(
32+
loc, arith::CmpIPredicate::slt, value, ub);
33+
Value inBounds =
34+
builder.createOrFold<arith::AndIOp>(loc, inBounds1, inBounds2);
35+
return inBounds;
36+
}
37+
2638
struct CastOpInterface
2739
: public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface,
2840
CastOp> {
@@ -172,6 +184,21 @@ struct CopyOpInterface
172184
}
173185
};
174186

187+
struct DimOpInterface
188+
: public RuntimeVerifiableOpInterface::ExternalModel<DimOpInterface,
189+
DimOp> {
190+
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
191+
Location loc) const {
192+
auto dimOp = cast<DimOp>(op);
193+
Value rank = builder.create<RankOp>(loc, dimOp.getSource());
194+
Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
195+
builder.create<cf::AssertOp>(
196+
loc, generateInBoundsCheck(builder, loc, dimOp.getIndex(), zero, rank),
197+
RuntimeVerifiableOpInterface::generateErrorMessage(
198+
op, "index is out of bounds"));
199+
}
200+
};
201+
175202
/// Verifies that the indices on load/store ops are in-bounds of the memref's
176203
/// index space: 0 <= index#i < dim#i
177204
template <typename LoadStoreOp>
@@ -192,19 +219,12 @@ struct LoadStoreOpInterface
192219
auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
193220
Value assertCond;
194221
for (auto i : llvm::seq<int64_t>(0, rank)) {
195-
auto index = indices[i];
196-
197-
auto dimOp = builder.createOrFold<memref::DimOp>(loc, memref, i);
198-
199-
auto geLow = builder.createOrFold<arith::CmpIOp>(
200-
loc, arith::CmpIPredicate::sge, index, zero);
201-
auto ltHigh = builder.createOrFold<arith::CmpIOp>(
202-
loc, arith::CmpIPredicate::slt, index, dimOp);
203-
auto andOp = builder.createOrFold<arith::AndIOp>(loc, geLow, ltHigh);
204-
222+
Value dimOp = builder.createOrFold<memref::DimOp>(loc, memref, i);
223+
Value inBounds =
224+
generateInBoundsCheck(builder, loc, indices[i], zero, dimOp);
205225
assertCond =
206-
i > 0 ? builder.createOrFold<arith::AndIOp>(loc, assertCond, andOp)
207-
: andOp;
226+
i > 0 ? builder.createOrFold<arith::AndIOp>(loc, assertCond, inBounds)
227+
: inBounds;
208228
}
209229
builder.create<cf::AssertOp>(
210230
loc, assertCond,
@@ -380,6 +400,7 @@ void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels(
380400
registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
381401
CastOp::attachInterface<CastOpInterface>(*ctx);
382402
CopyOp::attachInterface<CopyOpInterface>(*ctx);
403+
DimOp::attachInterface<DimOpInterface>(*ctx);
383404
ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
384405
LoadOp::attachInterface<LoadStoreOpInterface<LoadOp>>(*ctx);
385406
ReinterpretCastOp::attachInterface<ReinterpretCastOpInterface>(*ctx);

mlir/lib/Transforms/GenerateRuntimeVerification.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,19 @@ struct GenerateRuntimeVerificationPass
2828
} // namespace
2929

3030
void GenerateRuntimeVerificationPass::runOnOperation() {
31+
// The implementation of the RuntimeVerifiableOpInterface may create ops that
32+
// can be verified. We don't want to generate verification for IR that
33+
// performs verification, so gather all runtime-verifiable ops first.
34+
SmallVector<RuntimeVerifiableOpInterface> ops;
3135
getOperation()->walk([&](RuntimeVerifiableOpInterface verifiableOp) {
32-
OpBuilder builder(getOperation()->getContext());
36+
ops.push_back(verifiableOp);
37+
});
38+
39+
OpBuilder builder(getOperation()->getContext());
40+
for (RuntimeVerifiableOpInterface verifiableOp : ops) {
3341
builder.setInsertionPoint(verifiableOp);
3442
verifiableOp.generateRuntimeVerification(builder, verifiableOp.getLoc());
35-
});
43+
};
3644
}
3745

3846
std::unique_ptr<Pass> mlir::createGenerateRuntimeVerificationPass() {

mlir/test/Integration/Dialect/MemRef/cast-runtime-verification.mlir

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
1-
// RUN: mlir-opt %s -generate-runtime-verification -finalize-memref-to-llvm \
1+
// RUN: mlir-opt %s -generate-runtime-verification \
22
// RUN: -test-cf-assert \
3-
// RUN: -convert-func-to-llvm \
4-
// RUN: -convert-arith-to-llvm \
5-
// RUN: -reconcile-unrealized-casts | \
3+
// RUN: -expand-strided-metadata \
4+
// RUN: -convert-to-llvm | \
65
// RUN: mlir-runner -e main -entry-point-result=void \
76
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
87
// RUN: FileCheck %s
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
// RUN: mlir-opt %s -generate-runtime-verification \
2+
// RUN: -expand-strided-metadata \
3+
// RUN: -test-cf-assert \
4+
// RUN: -convert-to-llvm | \
5+
// RUN: mlir-runner -e main -entry-point-result=void \
6+
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
7+
// RUN: FileCheck %s
8+
9+
func.func @main() {
10+
%c4 = arith.constant 4 : index
11+
%alloca = memref.alloca() : memref<1xf32>
12+
13+
// CHECK: ERROR: Runtime op verification failed
14+
// CHECK-NEXT: "memref.dim"(%{{.*}}, %{{.*}}) : (memref<1xf32>, index) -> index
15+
// CHECK-NEXT: ^ index is out of bounds
16+
// CHECK-NEXT: Location: loc({{.*}})
17+
%dim = memref.dim %alloca, %c4 : memref<1xf32>
18+
19+
return
20+
}

mlir/test/Integration/Dialect/MemRef/load-runtime-verification.mlir

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
// RUN: mlir-opt %s -generate-runtime-verification \
2-
// RUN: -expand-strided-metadata \
3-
// RUN: -finalize-memref-to-llvm \
42
// RUN: -test-cf-assert \
5-
// RUN: -convert-func-to-llvm \
6-
// RUN: -convert-arith-to-llvm \
7-
// RUN: -reconcile-unrealized-casts | \
3+
// RUN: -expand-strided-metadata \
4+
// RUN: -lower-affine \
5+
// RUN: -convert-to-llvm | \
86
// RUN: mlir-runner -e main -entry-point-result=void \
97
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
108
// RUN: FileCheck %s

mlir/test/Integration/Dialect/MemRef/reinterpret-cast-runtime-verification.mlir

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
// RUN: mlir-opt %s -generate-runtime-verification \
2-
// RUN: -lower-affine \
3-
// RUN: -finalize-memref-to-llvm \
42
// RUN: -test-cf-assert \
5-
// RUN: -convert-func-to-llvm \
6-
// RUN: -convert-arith-to-llvm \
7-
// RUN: -reconcile-unrealized-casts | \
3+
// RUN: -expand-strided-metadata \
4+
// RUN: -lower-affine \
5+
// RUN: -convert-to-llvm | \
86
// RUN: mlir-runner -e main -entry-point-result=void \
97
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
108
// RUN: FileCheck %s

mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,8 @@
11
// RUN: mlir-opt %s -generate-runtime-verification \
2+
// RUN: -test-cf-assert \
23
// RUN: -expand-strided-metadata \
34
// RUN: -lower-affine \
4-
// RUN: -finalize-memref-to-llvm \
5-
// RUN: -test-cf-assert \
6-
// RUN: -convert-func-to-llvm \
7-
// RUN: -convert-arith-to-llvm \
8-
// RUN: -reconcile-unrealized-casts | \
5+
// RUN: -convert-to-llvm | \
96
// RUN: mlir-runner -e main -entry-point-result=void \
107
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
118
// RUN: FileCheck %s

0 commit comments

Comments
 (0)