Skip to content

Commit f54cdc5

Browse files
authored
[mlir] IntegerRangeAnalysis: add support for vector type (llvm#112292)
Treat integer range for vector type as union of ranges of individual elements. With this semantics, most arith ops on vectors will work out of the box, the only special handling needed for constants and vector elements manipulation ops. The end goal of these changes is to be able to optimize vectorized index calculations.
1 parent 17bad1a commit f54cdc5

File tree

8 files changed

+211
-18
lines changed

8 files changed

+211
-18
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,20 +13,21 @@
1313
#ifndef MLIR_DIALECT_VECTOR_IR_VECTOR_OPS
1414
#define MLIR_DIALECT_VECTOR_IR_VECTOR_OPS
1515

16-
include "mlir/Dialect/Vector/IR/Vector.td"
17-
include "mlir/Dialect/Vector/IR/VectorAttributes.td"
1816
include "mlir/Dialect/Arith/IR/ArithBase.td"
1917
include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.td"
2018
include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.td"
2119
include "mlir/Dialect/Vector/Interfaces/MaskingOpInterface.td"
22-
include "mlir/IR/EnumAttr.td"
20+
include "mlir/Dialect/Vector/IR/Vector.td"
21+
include "mlir/Dialect/Vector/IR/VectorAttributes.td"
2322
include "mlir/Interfaces/ControlFlowInterfaces.td"
2423
include "mlir/Interfaces/DestinationStyleOpInterface.td"
24+
include "mlir/Interfaces/InferIntRangeInterface.td"
2525
include "mlir/Interfaces/InferTypeOpInterface.td"
2626
include "mlir/Interfaces/SideEffectInterfaces.td"
2727
include "mlir/Interfaces/VectorInterfaces.td"
2828
include "mlir/Interfaces/ViewLikeInterface.td"
2929
include "mlir/IR/BuiltinAttributes.td"
30+
include "mlir/IR/EnumAttr.td"
3031

3132
// TODO: Add an attribute to specify a different algebra with operators other
3233
// than the current set: {*, +}.
@@ -346,6 +347,7 @@ def Vector_MultiDimReductionOp :
346347

347348
def Vector_BroadcastOp :
348349
Vector_Op<"broadcast", [Pure,
350+
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
349351
PredOpTrait<"source operand and result have same element type",
350352
TCresVTEtIsSameAsOpBase<0, 0>>]>,
351353
Arguments<(ins AnyType:$source)>,
@@ -627,6 +629,7 @@ def Vector_DeinterleaveOp :
627629

628630
def Vector_ExtractElementOp :
629631
Vector_Op<"extractelement", [Pure,
632+
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
630633
TypesMatchWith<"result type matches element type of vector operand",
631634
"vector", "result",
632635
"::llvm::cast<VectorType>($_self).getElementType()">]>,
@@ -673,6 +676,7 @@ def Vector_ExtractElementOp :
673676

674677
def Vector_ExtractOp :
675678
Vector_Op<"extract", [Pure,
679+
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
676680
PredOpTrait<"operand and result have same element type",
677681
TCresVTEtIsSameAsOpBase<0, 0>>,
678682
InferTypeOpAdaptorWithIsCompatible]> {
@@ -810,6 +814,7 @@ def Vector_FromElementsOp : Vector_Op<"from_elements", [
810814

811815
def Vector_InsertElementOp :
812816
Vector_Op<"insertelement", [Pure,
817+
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
813818
TypesMatchWith<"source operand type matches element type of result",
814819
"result", "source",
815820
"::llvm::cast<VectorType>($_self).getElementType()">,
@@ -858,6 +863,7 @@ def Vector_InsertElementOp :
858863

859864
def Vector_InsertOp :
860865
Vector_Op<"insert", [Pure,
866+
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
861867
PredOpTrait<"source operand and result have same element type",
862868
TCresVTEtIsSameAsOpBase<0, 0>>,
863869
AllTypesMatch<["dest", "result"]>]> {
@@ -2204,7 +2210,9 @@ def Vector_CompressStoreOp :
22042210
}
22052211

22062212
def Vector_ShapeCastOp :
2207-
Vector_Op<"shape_cast", [Pure]>,
2213+
Vector_Op<"shape_cast", [Pure,
2214+
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>
2215+
]>,
22082216
Arguments<(ins AnyVectorOfAnyRank:$source)>,
22092217
Results<(outs AnyVectorOfAnyRank:$result)> {
22102218
let summary = "shape_cast casts between vector shapes";
@@ -2801,6 +2809,7 @@ def Vector_FlatTransposeOp : Vector_Op<"flat_transpose", [Pure,
28012809

28022810
def Vector_SplatOp : Vector_Op<"splat", [
28032811
Pure,
2812+
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
28042813
TypesMatchWith<"operand type matches element type of result",
28052814
"aggregate", "input",
28062815
"::llvm::cast<VectorType>($_self).getElementType()">

mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "mlir/IR/BuiltinAttributes.h"
2020
#include "mlir/IR/Dialect.h"
2121
#include "mlir/IR/OpDefinition.h"
22+
#include "mlir/IR/TypeUtilities.h"
2223
#include "mlir/IR/Value.h"
2324
#include "mlir/Interfaces/ControlFlowInterfaces.h"
2425
#include "mlir/Interfaces/InferIntRangeInterface.h"
@@ -53,9 +54,10 @@ void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const {
5354
dialect = parent->getDialect();
5455
else
5556
dialect = value.getParentBlock()->getParentOp()->getDialect();
57+
58+
Type type = getElementTypeOrSelf(value);
5659
solver->propagateIfChanged(
57-
cv, cv->join(ConstantValue(IntegerAttr::get(value.getType(), *constant),
58-
dialect)));
60+
cv, cv->join(ConstantValue(IntegerAttr::get(type, *constant), dialect)));
5961
}
6062

6163
LogicalResult IntegerRangeAnalysis::visitOperation(

mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,22 @@ convertArithOverflowFlags(arith::IntegerOverflowFlags flags) {
3535

3636
void arith::ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
3737
SetIntRangeFn setResultRange) {
38-
auto constAttr = llvm::dyn_cast_or_null<IntegerAttr>(getValue());
39-
if (constAttr) {
40-
const APInt &value = constAttr.getValue();
38+
if (auto scalarCstAttr = llvm::dyn_cast_or_null<IntegerAttr>(getValue())) {
39+
const APInt &value = scalarCstAttr.getValue();
4140
setResultRange(getResult(), ConstantIntRanges::constant(value));
41+
return;
42+
}
43+
if (auto arrayCstAttr =
44+
llvm::dyn_cast_or_null<DenseIntElementsAttr>(getValue())) {
45+
std::optional<ConstantIntRanges> result;
46+
for (const APInt &val : arrayCstAttr) {
47+
auto range = ConstantIntRanges::constant(val);
48+
result = (result ? result->rangeUnion(range) : range);
49+
}
50+
51+
assert(result && "Zero-sized vectors are not allowed");
52+
setResultRange(getResult(), *result);
53+
return;
4254
}
4355
}
4456

mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,21 +51,27 @@ static LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver,
5151
if (!maybeConstValue.has_value())
5252
return failure();
5353

54+
Type type = value.getType();
55+
Location loc = value.getLoc();
5456
Operation *maybeDefiningOp = value.getDefiningOp();
5557
Dialect *valueDialect =
5658
maybeDefiningOp ? maybeDefiningOp->getDialect()
5759
: value.getParentRegion()->getParentOp()->getDialect();
58-
Attribute constAttr =
59-
rewriter.getIntegerAttr(value.getType(), *maybeConstValue);
60-
Operation *constOp = valueDialect->materializeConstant(
61-
rewriter, constAttr, value.getType(), value.getLoc());
60+
61+
Attribute constAttr;
62+
if (auto shaped = dyn_cast<ShapedType>(type)) {
63+
constAttr = mlir::DenseIntElementsAttr::get(shaped, *maybeConstValue);
64+
} else {
65+
constAttr = rewriter.getIntegerAttr(type, *maybeConstValue);
66+
}
67+
Operation *constOp =
68+
valueDialect->materializeConstant(rewriter, constAttr, type, loc);
6269
// Fall back to arith.constant if the dialect materializer doesn't know what
6370
// to do with an integer constant.
6471
if (!constOp)
6572
constOp = rewriter.getContext()
6673
->getLoadedDialect<ArithDialect>()
67-
->materializeConstant(rewriter, constAttr, value.getType(),
68-
value.getLoc());
74+
->materializeConstant(rewriter, constAttr, type, loc);
6975
if (!constOp)
7076
return failure();
7177

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1221,6 +1221,11 @@ void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results,
12211221
// ExtractElementOp
12221222
//===----------------------------------------------------------------------===//
12231223

1224+
void ExtractElementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
1225+
SetIntRangeFn setResultRanges) {
1226+
setResultRanges(getResult(), argRanges.front());
1227+
}
1228+
12241229
void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
12251230
Value source) {
12261231
result.addOperands({source});
@@ -1273,6 +1278,11 @@ OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
12731278
// ExtractOp
12741279
//===----------------------------------------------------------------------===//
12751280

1281+
void ExtractOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
1282+
SetIntRangeFn setResultRanges) {
1283+
setResultRanges(getResult(), argRanges.front());
1284+
}
1285+
12761286
void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
12771287
Value source, int64_t position) {
12781288
build(builder, result, source, ArrayRef<int64_t>{position});
@@ -2252,6 +2262,11 @@ void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
22522262
// BroadcastOp
22532263
//===----------------------------------------------------------------------===//
22542264

2265+
void BroadcastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
2266+
SetIntRangeFn setResultRanges) {
2267+
setResultRanges(getResult(), argRanges.front());
2268+
}
2269+
22552270
/// Return the dimensions of the result vector that were formerly ones in the
22562271
/// source tensor and thus correspond to "dim-1" broadcasting.
22572272
static llvm::SetVector<int64_t>
@@ -2713,6 +2728,11 @@ void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
27132728
// InsertElementOp
27142729
//===----------------------------------------------------------------------===//
27152730

2731+
void InsertElementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
2732+
SetIntRangeFn setResultRanges) {
2733+
setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
2734+
}
2735+
27162736
void InsertElementOp::build(OpBuilder &builder, OperationState &result,
27172737
Value source, Value dest) {
27182738
build(builder, result, source, dest, {});
@@ -2762,6 +2782,11 @@ OpFoldResult vector::InsertElementOp::fold(FoldAdaptor adaptor) {
27622782
// InsertOp
27632783
//===----------------------------------------------------------------------===//
27642784

2785+
void vector::InsertOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
2786+
SetIntRangeFn setResultRanges) {
2787+
setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
2788+
}
2789+
27652790
void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
27662791
Value source, Value dest, int64_t position) {
27672792
build(builder, result, source, dest, ArrayRef<int64_t>{position});
@@ -5277,6 +5302,11 @@ void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
52775302
// ShapeCastOp
52785303
//===----------------------------------------------------------------------===//
52795304

5305+
void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
5306+
SetIntRangeFn setResultRanges) {
5307+
setResultRanges(getResult(), argRanges.front());
5308+
}
5309+
52805310
/// Returns true if each element of 'a' is equal to the product of a contiguous
52815311
/// sequence of the elements of 'b'. Returns false otherwise.
52825312
static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
@@ -6423,6 +6453,11 @@ OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
64236453
return SplatElementsAttr::get(getType(), {constOperand});
64246454
}
64256455

6456+
void SplatOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
6457+
SetIntRangeFn setResultRanges) {
6458+
setResultRanges(getResult(), argRanges.front());
6459+
}
6460+
64266461
//===----------------------------------------------------------------------===//
64276462
// WarpExecuteOnLane0Op
64286463
//===----------------------------------------------------------------------===//
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
// RUN: mlir-opt -int-range-optimizations -canonicalize %s | FileCheck %s
2+
3+
4+
// CHECK-LABEL: func @constant_vec
5+
// CHECK: test.reflect_bounds {smax = 7 : index, smin = 0 : index, umax = 7 : index, umin = 0 : index}
6+
func.func @constant_vec() -> vector<8xindex> {
7+
%0 = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
8+
%1 = test.reflect_bounds %0 : vector<8xindex>
9+
func.return %1 : vector<8xindex>
10+
}
11+
12+
// CHECK-LABEL: func @constant_splat
13+
// CHECK: test.reflect_bounds {smax = 3 : si32, smin = 3 : si32, umax = 3 : ui32, umin = 3 : ui32}
14+
func.func @constant_splat() -> vector<8xi32> {
15+
%0 = arith.constant dense<3> : vector<8xi32>
16+
%1 = test.reflect_bounds %0 : vector<8xi32>
17+
func.return %1 : vector<8xi32>
18+
}
19+
20+
// CHECK-LABEL: func @vector_splat
21+
// CHECK: test.reflect_bounds {smax = 5 : index, smin = 4 : index, umax = 5 : index, umin = 4 : index}
22+
func.func @vector_splat() -> vector<4xindex> {
23+
%0 = test.with_bounds { umin = 4 : index, umax = 5 : index, smin = 4 : index, smax = 5 : index } : index
24+
%1 = vector.splat %0 : vector<4xindex>
25+
%2 = test.reflect_bounds %1 : vector<4xindex>
26+
func.return %2 : vector<4xindex>
27+
}
28+
29+
// CHECK-LABEL: func @vector_broadcast
30+
// CHECK: test.reflect_bounds {smax = 5 : index, smin = 4 : index, umax = 5 : index, umin = 4 : index}
31+
func.func @vector_broadcast() -> vector<4x16xindex> {
32+
%0 = test.with_bounds { umin = 4 : index, umax = 5 : index, smin = 4 : index, smax = 5 : index } : vector<16xindex>
33+
%1 = vector.broadcast %0 : vector<16xindex> to vector<4x16xindex>
34+
%2 = test.reflect_bounds %1 : vector<4x16xindex>
35+
func.return %2 : vector<4x16xindex>
36+
}
37+
38+
// CHECK-LABEL: func @vector_shape_cast
39+
// CHECK: test.reflect_bounds {smax = 5 : index, smin = 4 : index, umax = 5 : index, umin = 4 : index}
40+
func.func @vector_shape_cast() -> vector<4x4xindex> {
41+
%0 = test.with_bounds { umin = 4 : index, umax = 5 : index, smin = 4 : index, smax = 5 : index } : vector<16xindex>
42+
%1 = vector.shape_cast %0 : vector<16xindex> to vector<4x4xindex>
43+
%2 = test.reflect_bounds %1 : vector<4x4xindex>
44+
func.return %2 : vector<4x4xindex>
45+
}
46+
47+
// CHECK-LABEL: func @vector_extract
48+
// CHECK: test.reflect_bounds {smax = 6 : index, smin = 5 : index, umax = 6 : index, umin = 5 : index}
49+
func.func @vector_extract() -> index {
50+
%0 = test.with_bounds { umin = 5 : index, umax = 6 : index, smin = 5 : index, smax = 6 : index } : vector<4xindex>
51+
%1 = vector.extract %0[0] : index from vector<4xindex>
52+
%2 = test.reflect_bounds %1 : index
53+
func.return %2 : index
54+
}
55+
56+
// CHECK-LABEL: func @vector_extractelement
57+
// CHECK: test.reflect_bounds {smax = 7 : index, smin = 6 : index, umax = 7 : index, umin = 6 : index}
58+
func.func @vector_extractelement() -> index {
59+
%c0 = arith.constant 0 : index
60+
%0 = test.with_bounds { umin = 6 : index, umax = 7 : index, smin = 6 : index, smax = 7 : index } : vector<4xindex>
61+
%1 = vector.extractelement %0[%c0 : index] : vector<4xindex>
62+
%2 = test.reflect_bounds %1 : index
63+
func.return %2 : index
64+
}
65+
66+
// CHECK-LABEL: func @vector_add
67+
// CHECK: test.reflect_bounds {smax = 12 : index, smin = 10 : index, umax = 12 : index, umin = 10 : index}
68+
func.func @vector_add() -> vector<4xindex> {
69+
%0 = test.with_bounds { umin = 4 : index, umax = 5 : index, smin = 4 : index, smax = 5 : index } : vector<4xindex>
70+
%1 = test.with_bounds { umin = 6 : index, umax = 7 : index, smin = 6 : index, smax = 7 : index } : vector<4xindex>
71+
%2 = arith.addi %0, %1 : vector<4xindex>
72+
%3 = test.reflect_bounds %2 : vector<4xindex>
73+
func.return %3 : vector<4xindex>
74+
}
75+
76+
// CHECK-LABEL: func @vector_insert
77+
// CHECK: test.reflect_bounds {smax = 8 : index, smin = 5 : index, umax = 8 : index, umin = 5 : index}
78+
func.func @vector_insert() -> vector<4xindex> {
79+
%0 = test.with_bounds { umin = 5 : index, umax = 7 : index, smin = 5 : index, smax = 7 : index } : vector<4xindex>
80+
%1 = test.with_bounds { umin = 6 : index, umax = 8 : index, smin = 6 : index, smax = 8 : index } : index
81+
%2 = vector.insert %1, %0[0] : index into vector<4xindex>
82+
%3 = test.reflect_bounds %2 : vector<4xindex>
83+
func.return %3 : vector<4xindex>
84+
}
85+
86+
// CHECK-LABEL: func @vector_insertelement
87+
// CHECK: test.reflect_bounds {smax = 8 : index, smin = 5 : index, umax = 8 : index, umin = 5 : index}
88+
func.func @vector_insertelement() -> vector<4xindex> {
89+
%c0 = arith.constant 0 : index
90+
%0 = test.with_bounds { umin = 5 : index, umax = 7 : index, smin = 5 : index, smax = 7 : index } : vector<4xindex>
91+
%1 = test.with_bounds { umin = 6 : index, umax = 8 : index, smin = 6 : index, smax = 8 : index } : index
92+
%2 = vector.insertelement %1, %0[%c0 : index] : vector<4xindex>
93+
%3 = test.reflect_bounds %2 : vector<4xindex>
94+
func.return %3 : vector<4xindex>
95+
}
96+
97+
// CHECK-LABEL: func @test_loaded_vector_extract
98+
// No bounds
99+
// CHECK: test.reflect_bounds %{{.*}} : i32
100+
func.func @test_loaded_vector_extract(%memref : memref<16xi32>) -> i32 {
101+
%c0 = arith.constant 0 : index
102+
%v = vector.load %memref[%c0] : memref<16xi32>, vector<4xi32>
103+
%e = vector.extract %v[0] : i32 from vector<4xi32>
104+
%bounds = test.reflect_bounds %e : i32
105+
func.return %bounds : i32
106+
}

mlir/test/lib/Dialect/Test/TestOpDefs.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -760,12 +760,13 @@ void TestReflectBoundsOp::inferResultRanges(
760760
Type sIntTy, uIntTy;
761761
// For plain `IntegerType`s, we can derive the appropriate signed and unsigned
762762
// Types for the Attributes.
763-
if (auto intTy = llvm::dyn_cast<IntegerType>(getType())) {
763+
Type type = getElementTypeOrSelf(getType());
764+
if (auto intTy = llvm::dyn_cast<IntegerType>(type)) {
764765
unsigned bitwidth = intTy.getWidth();
765766
sIntTy = b.getIntegerType(bitwidth, /*isSigned=*/true);
766767
uIntTy = b.getIntegerType(bitwidth, /*isSigned=*/false);
767768
} else
768-
sIntTy = uIntTy = getType();
769+
sIntTy = uIntTy = type;
769770

770771
setUminAttr(b.getIntegerAttr(uIntTy, range.umin()));
771772
setUmaxAttr(b.getIntegerAttr(uIntTy, range.umax()));

0 commit comments

Comments
 (0)