Skip to content

Commit 96eeb6c

Browse files
authored
[mlir][llvm] Support nusw and nuw in GEP (llvm#137272)
nusw and nuw were introduced in getelementptr, this patch plumbs them in MLIR. Since inbounds implies nusw, this patch also adds an inboundsFlag to represent the concept of raw inbounds with no nusw implication, and have the inbounds literal captured as the combination of inboundsFlag and nusw. Fixes: iree#20482 Signed-off-by: Lin, Peiyong <linpyong@gmail.com>
1 parent 857ac4c commit 96eeb6c

File tree

9 files changed

+105
-19
lines changed

9 files changed

+105
-19
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -876,4 +876,32 @@ def UWTableKindEnum : LLVM_EnumAttr<
876876
let cppNamespace = "::mlir::LLVM::uwtable";
877877
}
878878

879+
//===----------------------------------------------------------------------===//
880+
// GEPNoWrapFlags
881+
//===----------------------------------------------------------------------===//
882+
883+
// These values must match llvm::GEPNoWrapFlags ones.
884+
// See llvm/include/llvm/IR/GEPNoWrapFlags.h.
885+
// Since inbounds implies nusw, create an inboundsFlag that represents the
886+
// concept of raw inbounds with no nusw implication and the actual inbounds
887+
// literal will be captured as the combination of inboundsFlag and nusw.
888+
889+
def GEPNone : I32BitEnumCaseNone<"none">;
890+
def GEPInboundsFlag : I32BitEnumCaseBit<"inboundsFlag", 0, "inbounds_flag">;
891+
def GEPNusw : I32BitEnumCaseBit<"nusw", 1>;
892+
def GEPNuw : I32BitEnumCaseBit<"nuw", 2>;
893+
def GEPInbounds : BitEnumCaseGroup<"inbounds", [GEPInboundsFlag, GEPNusw]>;
894+
895+
def GEPNoWrapFlags : I32BitEnum<
896+
"GEPNoWrapFlags",
897+
"::mlir::LLVM::GEPNoWrapFlags",
898+
[GEPNone, GEPInboundsFlag, GEPNusw, GEPNuw, GEPInbounds]> {
899+
let cppNamespace = "::mlir::LLVM";
900+
let printBitEnumPrimaryGroups = 1;
901+
}
902+
903+
def GEPNoWrapFlagsProp : EnumProp<GEPNoWrapFlags> {
904+
let defaultValue = interfaceType # "::none";
905+
}
906+
879907
#endif // LLVMIR_ENUMS

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,7 @@ def LLVM_GEPOp : LLVM_Op<"getelementptr", [Pure,
291291
Variadic<LLVM_ScalarOrVectorOf<AnySignlessInteger>>:$dynamicIndices,
292292
DenseI32ArrayAttr:$rawConstantIndices,
293293
TypeAttr:$elem_type,
294-
UnitAttr:$inbounds);
294+
GEPNoWrapFlagsProp:$noWrapFlags);
295295
let results = (outs LLVM_ScalarOrVectorOf<LLVM_AnyPointer>:$res);
296296
let skipDefaultBuilders = 1;
297297

@@ -303,8 +303,12 @@ def LLVM_GEPOp : LLVM_Op<"getelementptr", [Pure,
303303
as indices. In the case of indexing within a structure, it is required to
304304
either use constant indices directly, or supply a constant SSA value.
305305

306-
An optional 'inbounds' attribute specifies the low-level pointer arithmetic
306+
The no-wrap flags can be used to specify the low-level pointer arithmetic
307307
overflow behavior that LLVM uses after lowering the operation to LLVM IR.
308+
Valid options include 'inbounds' (pointer arithmetic must be within object
309+
bounds), 'nusw' (no unsigned signed wrap), and 'nuw' (no unsigned wrap).
310+
Note that 'inbounds' implies 'nusw' which is ensured by the enum
311+
definition. The flags can be set individually or in combination.
308312

309313
Examples:
310314

@@ -323,10 +327,12 @@ def LLVM_GEPOp : LLVM_Op<"getelementptr", [Pure,
323327

324328
let builders = [
325329
OpBuilder<(ins "Type":$resultType, "Type":$elementType, "Value":$basePtr,
326-
"ValueRange":$indices, CArg<"bool", "false">:$inbounds,
330+
"ValueRange":$indices,
331+
CArg<"GEPNoWrapFlags", "GEPNoWrapFlags::none">:$noWrapFlags,
327332
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
328333
OpBuilder<(ins "Type":$resultType, "Type":$elementType, "Value":$basePtr,
329-
"ArrayRef<GEPArg>":$indices, CArg<"bool", "false">:$inbounds,
334+
"ArrayRef<GEPArg>":$indices,
335+
CArg<"GEPNoWrapFlags", "GEPNoWrapFlags::none">:$noWrapFlags,
330336
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
331337
];
332338
let llvmBuilder = [{
@@ -343,10 +349,13 @@ def LLVM_GEPOp : LLVM_Op<"getelementptr", [Pure,
343349
}
344350
Type baseElementType = op.getElemType();
345351
llvm::Type *elementType = moduleTranslation.convertType(baseElementType);
346-
$res = builder.CreateGEP(elementType, $base, indices, "", $inbounds);
352+
$res = builder.CreateGEP(elementType, $base, indices, "",
353+
llvm::GEPNoWrapFlags::fromRaw(
354+
static_cast<unsigned>(
355+
op.getNoWrapFlags())));
347356
}];
348357
let assemblyFormat = [{
349-
(`inbounds` $inbounds^)?
358+
($noWrapFlags^)?
350359
$base `[` custom<GEPIndices>($dynamicIndices, $rawConstantIndices) `]` attr-dict
351360
`:` functional-type(operands, results) `,` $elem_type
352361
}];

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -673,29 +673,29 @@ static void destructureIndices(Type currType, ArrayRef<GEPArg> indices,
673673

674674
void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
675675
Type elementType, Value basePtr, ArrayRef<GEPArg> indices,
676-
bool inbounds, ArrayRef<NamedAttribute> attributes) {
676+
GEPNoWrapFlags noWrapFlags,
677+
ArrayRef<NamedAttribute> attributes) {
677678
SmallVector<int32_t> rawConstantIndices;
678679
SmallVector<Value> dynamicIndices;
679680
destructureIndices(elementType, indices, rawConstantIndices, dynamicIndices);
680681

681682
result.addTypes(resultType);
682683
result.addAttributes(attributes);
683-
result.addAttribute(getRawConstantIndicesAttrName(result.name),
684-
builder.getDenseI32ArrayAttr(rawConstantIndices));
685-
if (inbounds) {
686-
result.addAttribute(getInboundsAttrName(result.name),
687-
builder.getUnitAttr());
688-
}
689-
result.addAttribute(kElemTypeAttrName, TypeAttr::get(elementType));
684+
result.getOrAddProperties<Properties>().rawConstantIndices =
685+
builder.getDenseI32ArrayAttr(rawConstantIndices);
686+
result.getOrAddProperties<Properties>().noWrapFlags = noWrapFlags;
687+
result.getOrAddProperties<Properties>().elem_type =
688+
TypeAttr::get(elementType);
690689
result.addOperands(basePtr);
691690
result.addOperands(dynamicIndices);
692691
}
693692

694693
void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
695694
Type elementType, Value basePtr, ValueRange indices,
696-
bool inbounds, ArrayRef<NamedAttribute> attributes) {
695+
GEPNoWrapFlags noWrapFlags,
696+
ArrayRef<NamedAttribute> attributes) {
697697
build(builder, result, resultType, elementType, basePtr,
698-
SmallVector<GEPArg>(indices), inbounds, attributes);
698+
SmallVector<GEPArg>(indices), noWrapFlags, attributes);
699699
}
700700

701701
static ParseResult
@@ -794,6 +794,9 @@ LogicalResult LLVM::GEPOp::verify() {
794794
return emitOpError("expected as many dynamic indices as specified in '")
795795
<< getRawConstantIndicesAttrName().getValue() << "'";
796796

797+
if (getNoWrapFlags() == GEPNoWrapFlags::inboundsFlag)
798+
return emitOpError("'inbounds_flag' cannot be used directly.");
799+
797800
return verifyStructIndices(getElemType(), getIndices(),
798801
[&] { return emitOpError(); });
799802
}

mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -891,7 +891,7 @@ DeletionKind LLVM::GEPOp::rewire(const DestructurableMemorySlot &slot,
891891
auto byteType = IntegerType::get(builder.getContext(), 8);
892892
auto newPtr = builder.createOrFold<LLVM::GEPOp>(
893893
getLoc(), getResult().getType(), byteType, newSlot.ptr,
894-
ArrayRef<GEPArg>(accessInfo->subslotOffset), getInbounds());
894+
ArrayRef<GEPArg>(accessInfo->subslotOffset), getNoWrapFlags());
895895
getResult().replaceAllUsesWith(newPtr);
896896
return DeletionKind::Delete;
897897
}

mlir/lib/Target/LLVMIR/ModuleImport.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2035,8 +2035,9 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
20352035
}
20362036

20372037
Type type = convertType(inst->getType());
2038-
auto gepOp = builder.create<GEPOp>(loc, type, sourceElementType, *basePtr,
2039-
indices, gepInst->isInBounds());
2038+
auto gepOp = builder.create<GEPOp>(
2039+
loc, type, sourceElementType, *basePtr, indices,
2040+
static_cast<GEPNoWrapFlags>(gepInst->getNoWrapFlags().getRaw()));
20402041
mapValue(inst, gepOp);
20412042
return success();
20422043
}

mlir/test/Dialect/LLVMIR/invalid.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1819,3 +1819,11 @@ llvm.func @t1() -> !llvm.ptr {
18191819
^bb1:
18201820
llvm.return %0 : !llvm.ptr
18211821
}
1822+
1823+
// -----
1824+
1825+
llvm.func @gep_inbounds_flag_usage(%ptr: !llvm.ptr, %idx: i64) {
1826+
// expected-error@+1 {{'inbounds_flag' cannot be used directly}}
1827+
llvm.getelementptr inbounds_flag %ptr[%idx, 0, %idx] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.struct<(array<10 x f32>)>
1828+
llvm.return
1829+
}

mlir/test/Dialect/LLVMIR/roundtrip.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,16 @@ llvm.func @gep(%ptr: !llvm.ptr, %idx: i64, %ptr2: !llvm.ptr) {
236236
llvm.getelementptr %ptr[%idx, 1, 0] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(i32, struct<(i32, f32)>)>
237237
// CHECK: llvm.getelementptr inbounds %{{.*}}[%{{.*}}, 0, %{{.*}}] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.struct<(array<10 x f32>)>
238238
llvm.getelementptr inbounds %ptr2[%idx, 0, %idx] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.struct<(array<10 x f32>)>
239+
// CHECK: llvm.getelementptr inbounds|nuw %{{.*}}[%{{.*}}, 0, %{{.*}}] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.struct<(array<10 x f32>)>
240+
llvm.getelementptr inbounds | nuw %ptr2[%idx, 0, %idx] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.struct<(array<10 x f32>)>
241+
// CHECK: llvm.getelementptr inbounds %{{.*}}[%{{.*}}, 0, %{{.*}}] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.struct<(array<10 x f32>)>
242+
llvm.getelementptr inbounds | nusw %ptr2[%idx, 0, %idx] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.struct<(array<10 x f32>)>
243+
// CHECK: llvm.getelementptr nusw %{{.*}}[%{{.*}}, 0, %{{.*}}] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.struct<(array<10 x f32>)>
244+
llvm.getelementptr nusw %ptr2[%idx, 0, %idx] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.struct<(array<10 x f32>)>
245+
// CHECK: llvm.getelementptr nusw|nuw %{{.*}}[%{{.*}}, 0, %{{.*}}] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.struct<(array<10 x f32>)>
246+
llvm.getelementptr nusw | nuw %ptr2[%idx, 0, %idx] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.struct<(array<10 x f32>)>
247+
// CHECK: llvm.getelementptr nuw %{{.*}}[%{{.*}}, 0, %{{.*}}] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.struct<(array<10 x f32>)>
248+
llvm.getelementptr nuw %ptr2[%idx, 0, %idx] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.struct<(array<10 x f32>)>
239249
llvm.return
240250
}
241251

mlir/test/Target/LLVMIR/Import/instructions.ll

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -557,6 +557,25 @@ define void @gep_static_idx(ptr %ptr) {
557557

558558
; // -----
559559

560+
; CHECK-LABEL: @gep_no_wrap_flags
561+
; CHECK-SAME: %[[PTR:[a-zA-Z0-9]+]]
562+
define void @gep_no_wrap_flags(ptr %ptr) {
563+
; CHECK: %[[IDX:.+]] = llvm.mlir.constant(7 : i32)
564+
; CHECK: llvm.getelementptr inbounds %[[PTR]][%[[IDX]]] : (!llvm.ptr, i32) -> !llvm.ptr, f32
565+
%1 = getelementptr inbounds float, ptr %ptr, i32 7
566+
; CHECK: llvm.getelementptr nusw %[[PTR]][%[[IDX]]] : (!llvm.ptr, i32) -> !llvm.ptr, f32
567+
%2 = getelementptr nusw float, ptr %ptr, i32 7
568+
; CHECK: llvm.getelementptr nuw %[[PTR]][%[[IDX]]] : (!llvm.ptr, i32) -> !llvm.ptr, f32
569+
%3 = getelementptr nuw float, ptr %ptr, i32 7
570+
; CHECK: llvm.getelementptr nusw|nuw %[[PTR]][%[[IDX]]] : (!llvm.ptr, i32) -> !llvm.ptr, f32
571+
%4 = getelementptr nusw nuw float, ptr %ptr, i32 7
572+
; CHECK: llvm.getelementptr inbounds|nuw %[[PTR]][%[[IDX]]] : (!llvm.ptr, i32) -> !llvm.ptr, f32
573+
%5 = getelementptr inbounds nuw float, ptr %ptr, i32 7
574+
ret void
575+
}
576+
577+
; // -----
578+
560579
; CHECK: @varargs(...)
561580
declare void @varargs(...)
562581

mlir/test/Target/LLVMIR/llvmir.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1057,6 +1057,14 @@ llvm.func @gep(%ptr: !llvm.ptr, %idx: i64,
10571057
llvm.getelementptr %ptr[%idx, 1, 0] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(i32, struct<(i32, f32)>)>
10581058
// CHECK: = getelementptr inbounds { [10 x float] }, ptr %{{.*}}, i64 %{{.*}}, i32 0, i64 %{{.*}}
10591059
llvm.getelementptr inbounds %ptr2[%idx, 0, %idx] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.struct<(array<10 x f32>)>
1060+
// CHECK: = getelementptr inbounds nuw { [10 x float] }, ptr %{{.*}}, i64 %{{.*}}, i32 0, i64 %{{.*}}
1061+
llvm.getelementptr inbounds | nuw %ptr2[%idx, 0, %idx] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.struct<(array<10 x f32>)>
1062+
// CHECK: = getelementptr nusw { [10 x float] }, ptr %{{.*}}, i64 %{{.*}}, i32 0, i64 %{{.*}}
1063+
llvm.getelementptr nusw %ptr2[%idx, 0, %idx] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.struct<(array<10 x f32>)>
1064+
// CHECK: = getelementptr nusw nuw { [10 x float] }, ptr %{{.*}}, i64 %{{.*}}, i32 0, i64 %{{.*}}
1065+
llvm.getelementptr nusw | nuw %ptr2[%idx, 0, %idx] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.struct<(array<10 x f32>)>
1066+
// CHECK: = getelementptr nuw { [10 x float] }, ptr %{{.*}}, i64 %{{.*}}, i32 0, i64 %{{.*}}
1067+
llvm.getelementptr nuw %ptr2[%idx, 0, %idx] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.struct<(array<10 x f32>)>
10601068
llvm.return
10611069
}
10621070

0 commit comments

Comments
 (0)