Skip to content

Commit ec54ec6

Browse files
bondhugulaVinayaka Bandishti
and
Vinayaka Bandishti
authored
[MLIR][Affine] Improve memref region bounding size and shape computation (llvm#129009)
Improve memref region utility (`getConstantBoundingSizeAndShape`) to get its constant bounding size and shape using affine expressions/maps by also considering local variables in the system. Leads to significantly more precise and tighter bounding size and shape in the presence of div/mod expressions (as evident from the test cases). The approach is now more robust, proper, and complete. For affine fusion, this leads to private memrefs of accurate size in several cases. This also impacts other affine analysis-based passes like data copy generation that use memref regions. With contributions from `Vinayaka Bandishti <vinayaka@polymagelabs.com>` on `getConstantBoundingSizeAndShape` and `getConstantBoundOnDimSize`. Fixes: llvm#46317 Co-authored-by: Vinayaka Bandishti <vinayaka@polymagelabs.com>
1 parent e3c8e17 commit ec54ec6

File tree

11 files changed

+377
-154
lines changed

11 files changed

+377
-154
lines changed

mlir/include/mlir/Analysis/FlatLinearValueConstraints.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,31 @@ class FlatLinearConstraints : public presburger::IntegerPolyhedron {
118118
/// we explicitly introduce them here.
119119
using IntegerPolyhedron::addBound;
120120

121+
/// Returns a non-negative constant bound on the extent (upper bound - lower
122+
/// bound) of the specified variable if it is found to be a constant; returns
123+
/// std::nullopt if it's not a constant. This method treats symbolic
124+
/// variables specially, i.e., it looks for constant differences between
125+
/// affine expressions involving only the symbolic variables. 'lb', if
126+
/// provided, is set to the lower bound map associated with the constant
127+
/// difference, and similarly, `ub` to the upper bound. Note that 'lb', 'ub'
128+
/// are purely symbolic and will correspond to the symbolic variables of the
129+
/// constraint set.
130+
// Egs: 0 <= i <= 15, return 16.
131+
// s0 + 2 <= i <= s0 + 17, returns 16. (s0 has to be a symbol)
132+
// s0 + s1 + 16 <= d0 <= s0 + s1 + 31, returns 16.
133+
// s0 - 7 <= 8*j <= s0 returns 1 with lb = s0, lbDivisor = 8 (since lb =
134+
// ceil((s0 - 7) / 8) = floor(s0 / 8)).
135+
/// The difference between this method and
136+
/// IntegerRelation::getConstantBoundOnDimSize is that unlike the latter, this
137+
/// makes use of affine expressions and maps in its inference and provides
138+
/// output with affine maps; it thus handles local variables by detecting them
139+
/// as affine functions of the symbols when possible.
140+
std::optional<int64_t>
141+
getConstantBoundOnDimSize(MLIRContext *context, unsigned pos,
142+
AffineMap *lb = nullptr, AffineMap *ub = nullptr,
143+
unsigned *minLbPos = nullptr,
144+
unsigned *minUbPos = nullptr) const;
145+
121146
/// Returns the constraint system as an integer set. Returns a null integer
122147
/// set if the system has no constraints, or if an integer set couldn't be
123148
/// constructed as a result of a local variable's explicit representation not

mlir/include/mlir/Analysis/Presburger/IntegerRelation.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,13 @@ class IntegerRelation {
152152
/// intersection with no simplification of any sort attempted.
153153
void append(const IntegerRelation &other);
154154

155+
/// Finds an equality that equates the specified variable to a constant.
156+
/// Returns the position of the equality row. If 'symbolic' is set to true,
157+
/// symbols are also treated like a constant, i.e., an affine function of the
158+
/// symbols is also treated like a constant. Returns -1 if such an equality
159+
/// could not be found.
160+
int findEqualityToConstant(unsigned pos, bool symbolic = false) const;
161+
155162
/// Return the intersection of the two relations.
156163
/// If there are locals, they will be merged.
157164
IntegerRelation intersect(IntegerRelation other) const;

mlir/include/mlir/Dialect/Affine/Analysis/Utils.h

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -500,6 +500,8 @@ struct MemRefRegion {
500500
/// to slice operands (which correspond to symbols).
501501
/// If 'addMemRefDimBounds' is true, constant upper/lower bounds
502502
/// [0, memref.getDimSize(i)) are added for each MemRef dimension 'i'.
503+
/// If `dropLocalVars` is true, all local variables in `cst` are projected
504+
/// out.
503505
///
504506
/// For example, the memref region for this operation at loopDepth = 1 will
505507
/// be:
@@ -513,9 +515,14 @@ struct MemRefRegion {
513515
/// {memref = %A, write = false, {%i <= m0 <= %i + 7} }
514516
/// The last field is a 2-d FlatAffineValueConstraints symbolic in %i.
515517
///
518+
/// If `dropOuterIVs` is true, project out any IVs other than those among
519+
/// `loopDepth` surrounding IVs, which would be symbols. If `dropOuterIVs`
520+
/// is false, the IVs would be turned into local variables instead of being
521+
/// projected out.
516522
LogicalResult compute(Operation *op, unsigned loopDepth,
517523
const ComputationSliceState *sliceState = nullptr,
518-
bool addMemRefDimBounds = true);
524+
bool addMemRefDimBounds = true,
525+
bool dropLocalVars = true, bool dropOuterIVs = true);
519526

520527
FlatAffineValueConstraints *getConstraints() { return &cst; }
521528
const FlatAffineValueConstraints *getConstraints() const { return &cst; }
@@ -530,31 +537,18 @@ struct MemRefRegion {
530537
/// corresponding dimension-wise bounds major to minor. The number of elements
531538
/// and all the dimension-wise bounds are guaranteed to be non-negative. We
532539
/// use int64_t instead of uint64_t since index types can be at most
533-
/// int64_t. `lbs` are set to the lower bounds for each of the rank
534-
/// dimensions, and lbDivisors contains the corresponding denominators for
535-
/// floorDivs.
540+
/// int64_t. `lbs` are set to the lower bound maps for each of the rank
541+
/// dimensions where each of these maps is purely symbolic in the constraints
542+
/// set's symbols.
536543
std::optional<int64_t> getConstantBoundingSizeAndShape(
537544
SmallVectorImpl<int64_t> *shape = nullptr,
538-
std::vector<SmallVector<int64_t, 4>> *lbs = nullptr,
539-
SmallVectorImpl<int64_t> *lbDivisors = nullptr) const;
545+
SmallVectorImpl<AffineMap> *lbs = nullptr) const;
540546

541547
/// Gets the lower and upper bound map for the dimensional variable at
542548
/// `pos`.
543549
void getLowerAndUpperBound(unsigned pos, AffineMap &lbMap,
544550
AffineMap &ubMap) const;
545551

546-
/// A wrapper around FlatAffineValueConstraints::getConstantBoundOnDimSize().
547-
/// 'pos' corresponds to the position of the memref shape's dimension (major
548-
/// to minor) which matches 1:1 with the dimensional variable positions in
549-
/// 'cst'.
550-
std::optional<int64_t>
551-
getConstantBoundOnDimSize(unsigned pos,
552-
SmallVectorImpl<int64_t> *lb = nullptr,
553-
int64_t *lbFloorDivisor = nullptr) const {
554-
assert(pos < getRank() && "invalid position");
555-
return cst.getConstantBoundOnDimSize64(pos, lb);
556-
}
557-
558552
/// Returns the size of this MemRefRegion in bytes.
559553
std::optional<int64_t> getRegionSize();
560554

mlir/lib/Analysis/FlatLinearValueConstraints.cpp

Lines changed: 226 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -581,6 +581,49 @@ std::pair<AffineMap, AffineMap> FlatLinearConstraints::getLowerAndUpperBound(
581581
return {lbMap, ubMap};
582582
}
583583

584+
/// Express the pos^th identifier of `cst` as an affine expression in
585+
/// terms of other identifiers, if they are available in `exprs`, using the
586+
/// equality at position `idx` in `cst`. Populates `exprs` with such an
587+
/// expression if possible, and return true. Returns false otherwise.
588+
static bool detectAsExpr(const FlatLinearConstraints &cst, unsigned pos,
589+
unsigned idx, MLIRContext *context,
590+
SmallVectorImpl<AffineExpr> &exprs) {
591+
// Initialize with a `0` expression.
592+
auto expr = getAffineConstantExpr(0, context);
593+
594+
// Traverse `idx`th equality and construct the possible affine expression in
595+
// terms of known identifiers.
596+
unsigned j, e;
597+
for (j = 0, e = cst.getNumVars(); j < e; ++j) {
598+
if (j == pos)
599+
continue;
600+
int64_t c = cst.atEq64(idx, j);
601+
if (c == 0)
602+
continue;
603+
// If any of the involved IDs hasn't been found yet, we can't proceed.
604+
if (!exprs[j])
605+
break;
606+
expr = expr + exprs[j] * c;
607+
}
608+
if (j < e)
609+
// Can't construct expression as it depends on a yet uncomputed
610+
// identifier.
611+
return false;
612+
613+
// Add constant term to AffineExpr.
614+
expr = expr + cst.atEq64(idx, cst.getNumVars());
615+
int64_t vPos = cst.atEq64(idx, pos);
616+
assert(vPos != 0 && "expected non-zero here");
617+
if (vPos > 0)
618+
expr = (-expr).floorDiv(vPos);
619+
else
620+
// vPos < 0.
621+
expr = expr.floorDiv(-vPos);
622+
// Successfully constructed expression.
623+
exprs[pos] = expr;
624+
return true;
625+
}
626+
584627
/// Compute a representation of `num` identifiers starting at `offset` in `cst`
585628
/// as affine expressions involving other known identifiers. Each identifier's
586629
/// expression (in terms of known identifiers) is populated into `memo`.
@@ -636,41 +679,13 @@ static void computeUnknownVars(const FlatLinearConstraints &cst,
636679

637680
// Detect a variable as an expression of other variables.
638681
std::optional<unsigned> idx;
639-
if (!(idx = cst.findConstraintWithNonZeroAt(pos, /*isEq=*/true))) {
682+
if (!(idx = cst.findConstraintWithNonZeroAt(pos, /*isEq=*/true)))
640683
continue;
641-
}
642684

643-
// Build AffineExpr solving for variable 'pos' in terms of all others.
644-
auto expr = getAffineConstantExpr(0, context);
645-
unsigned j, e;
646-
for (j = 0, e = cst.getNumVars(); j < e; ++j) {
647-
if (j == pos)
648-
continue;
649-
int64_t c = cst.atEq64(*idx, j);
650-
if (c == 0)
651-
continue;
652-
// If any of the involved IDs hasn't been found yet, we can't proceed.
653-
if (!memo[j])
654-
break;
655-
expr = expr + memo[j] * c;
656-
}
657-
if (j < e)
658-
// Can't construct expression as it depends on a yet uncomputed
659-
// variable.
685+
if (detectAsExpr(cst, pos, *idx, context, memo)) {
686+
changed = true;
660687
continue;
661-
662-
// Add constant term to AffineExpr.
663-
expr = expr + cst.atEq64(*idx, cst.getNumVars());
664-
int64_t vPos = cst.atEq64(*idx, pos);
665-
assert(vPos != 0 && "expected non-zero here");
666-
if (vPos > 0)
667-
expr = (-expr).floorDiv(vPos);
668-
else
669-
// vPos < 0.
670-
expr = expr.floorDiv(-vPos);
671-
// Successfully constructed expression.
672-
memo[pos] = expr;
673-
changed = true;
688+
}
674689
}
675690
// This loop is guaranteed to reach a fixed point - since once an
676691
// variable's explicit form is computed (in memo[pos]), it's not updated
@@ -891,6 +906,185 @@ FlatLinearConstraints::computeLocalVars(SmallVectorImpl<AffineExpr> &memo,
891906
llvm::all_of(localExprs, [](AffineExpr expr) { return expr; }));
892907
}
893908

909+
/// Given an equality or inequality (`isEquality` used to disambiguate) of `cst`
910+
/// at `idx`, traverse and sum up `AffineExpr`s of all known ids other than the
911+
/// `pos`th. Known `AffineExpr`s are given in `exprs` (unknowns are null). If
912+
/// the equality/inequality contains any unknown id, return None. Otherwise
913+
/// return sum as `AffineExpr`.
914+
static std::optional<AffineExpr> getAsExpr(const FlatLinearConstraints &cst,
915+
unsigned pos, MLIRContext *context,
916+
ArrayRef<AffineExpr> exprs,
917+
unsigned idx, bool isEquality) {
918+
// Initialize with a `0` expression.
919+
auto expr = getAffineConstantExpr(0, context);
920+
921+
SmallVector<int64_t, 8> row =
922+
isEquality ? cst.getEquality64(idx) : cst.getInequality64(idx);
923+
924+
// Traverse the `idx`th constraint row and construct the possible affine expression in
925+
// terms of known identifiers.
926+
unsigned j, e;
927+
for (j = 0, e = cst.getNumVars(); j < e; ++j) {
928+
if (j == pos)
929+
continue;
930+
int64_t c = row[j];
931+
if (c == 0)
932+
continue;
933+
// If any of the involved IDs hasn't been found yet, we can't proceed.
934+
if (!exprs[j])
935+
break;
936+
expr = expr + exprs[j] * c;
937+
}
938+
if (j < e)
939+
// Can't construct expression as it depends on a yet uncomputed
940+
// identifier.
941+
return std::nullopt;
942+
943+
// Add constant term to AffineExpr.
944+
expr = expr + row[cst.getNumVars()];
945+
return expr;
946+
}
947+
948+
std::optional<int64_t> FlatLinearConstraints::getConstantBoundOnDimSize(
949+
MLIRContext *context, unsigned pos, AffineMap *lb, AffineMap *ub,
950+
unsigned *minLbPos, unsigned *minUbPos) const {
951+
952+
assert(pos < getNumDimVars() && "Invalid identifier position");
953+
954+
auto freeOfUnknownLocalVars = [&](ArrayRef<int64_t> cst,
955+
ArrayRef<AffineExpr> whiteListCols) {
956+
for (int i = getNumDimAndSymbolVars(), e = cst.size() - 1; i < e; ++i) {
957+
if (whiteListCols[i] && whiteListCols[i].isSymbolicOrConstant())
958+
continue;
959+
if (cst[i] != 0)
960+
return false;
961+
}
962+
return true;
963+
};
964+
965+
// Detect the necessary local variables first.
966+
SmallVector<AffineExpr, 8> memo(getNumVars(), AffineExpr());
967+
(void)computeLocalVars(memo, context);
968+
969+
// Find an equality for 'pos'^th identifier that equates it to some function
970+
// of the symbolic identifiers (+ constant).
971+
int eqPos = findEqualityToConstant(pos, /*symbolic=*/true);
972+
// If the equality involves a local var that can not be expressed as a
973+
// symbolic or constant affine expression, we bail out.
974+
if (eqPos != -1 && freeOfUnknownLocalVars(getEquality64(eqPos), memo)) {
975+
// This identifier can only take a single value.
976+
if (lb && detectAsExpr(*this, pos, eqPos, context, memo)) {
977+
AffineExpr equalityExpr =
978+
simplifyAffineExpr(memo[pos], 0, getNumSymbolVars());
979+
*lb = AffineMap::get(/*dimCount=*/0, getNumSymbolVars(), equalityExpr);
980+
if (ub)
981+
*ub = *lb;
982+
}
983+
if (minLbPos)
984+
*minLbPos = eqPos;
985+
if (minUbPos)
986+
*minUbPos = eqPos;
987+
return 1;
988+
}
989+
990+
// Positions of constraints that are lower/upper bounds on the variable.
991+
SmallVector<unsigned, 4> lbIndices, ubIndices;
992+
993+
// Note inequalities that give lower and upper bounds.
994+
getLowerAndUpperBoundIndices(pos, &lbIndices, &ubIndices,
995+
/*eqIndices=*/nullptr, /*offset=*/0,
996+
/*num=*/getNumDimVars());
997+
998+
std::optional<int64_t> minDiff = std::nullopt;
999+
unsigned minLbPosition = 0, minUbPosition = 0;
1000+
AffineExpr minLbExpr, minUbExpr;
1001+
1002+
// Traverse each lower bound and upper bound pair, to compute the difference
1003+
// between them.
1004+
for (unsigned ubPos : ubIndices) {
1005+
// Construct sum of all ids other than `pos`th in the given upper bound row.
1006+
std::optional<AffineExpr> maybeUbExpr =
1007+
getAsExpr(*this, pos, context, memo, ubPos, /*isEquality=*/false);
1008+
if (!maybeUbExpr.has_value() || !(*maybeUbExpr).isSymbolicOrConstant())
1009+
continue;
1010+
1011+
// Canonical form of an inequality that constrains the upper bound on
1012+
// an id `x_i` is of the form:
1013+
// `c_1*x_1 + c_2*x_2 + ... + c_0 >= 0`, where `c_i` <= -1.
1014+
// Therefore the upper bound on `x_i` will be
1015+
// `(
1016+
// sum(c_j*x_j) where j != i
1017+
// +
1018+
// c_0
1019+
// )
1020+
// /
1021+
// -(c_i)`. Division here is a floorDiv.
1022+
AffineExpr ubExpr = maybeUbExpr->floorDiv(-atIneq64(ubPos, pos));
1023+
assert(-atIneq64(ubPos, pos) > 0 && "invalid upper bound index");
1024+
1025+
// Go over each lower bound.
1026+
for (unsigned lbPos : lbIndices) {
1027+
// Construct sum of all ids other than `pos`th in the given lower bound
1028+
// row.
1029+
std::optional<AffineExpr> maybeLbExpr =
1030+
getAsExpr(*this, pos, context, memo, lbPos, /*isEquality=*/false);
1031+
if (!maybeLbExpr.has_value() || !(*maybeLbExpr).isSymbolicOrConstant())
1032+
continue;
1033+
1034+
// Canonical form of an inequality that is constraining the lower bound
1035+
// on an id `x_i` is of the form:
1036+
// `c_1*x_1 + c_2*x_2 + ... + c_0 >= 0`, where `c_i` >= 1.
1037+
// Therefore the lower bound on `x_i` will be
1038+
// `-(
1039+
// sum(c_j*x_j) where j != i
1040+
// +
1041+
// c_0
1042+
// )
1043+
// /
1044+
// c_i`. Division here is a ceilDiv.
1045+
int64_t divisor = atIneq64(lbPos, pos);
1046+
// We convert the `ceilDiv` to a `floorDiv` with the formula:
1047+
// `expr ceildiv divisor is (expr + divisor - 1) floordiv divisor`,
1048+
// since uniformly keeping divisions as `floorDiv` helps their
1049+
// simplification.
1050+
AffineExpr lbExpr = (-(*maybeLbExpr) + divisor - 1).floorDiv(divisor);
1051+
assert(atIneq64(lbPos, pos) > 0 && "invalid lower bound index");
1052+
1053+
AffineExpr difference =
1054+
simplifyAffineExpr(ubExpr - lbExpr + 1, 0, getNumSymbolVars());
1055+
// If the difference is not constant, ignore the lower bound - upper bound
1056+
// pair.
1057+
auto constantDiff = dyn_cast<AffineConstantExpr>(difference);
1058+
if (!constantDiff)
1059+
continue;
1060+
1061+
int64_t diffValue = constantDiff.getValue();
1062+
// This bound is non-negative by definition.
1063+
diffValue = std::max<int64_t>(diffValue, 0);
1064+
if (!minDiff || diffValue < *minDiff) {
1065+
minDiff = diffValue;
1066+
minLbPosition = lbPos;
1067+
minUbPosition = ubPos;
1068+
minLbExpr = lbExpr;
1069+
minUbExpr = ubExpr;
1070+
}
1071+
}
1072+
}
1073+
1074+
// Populate outputs where available and needed.
1075+
if (lb && minDiff) {
1076+
*lb = AffineMap::get(/*dimCount=*/0, getNumSymbolVars(), minLbExpr);
1077+
}
1078+
if (ub)
1079+
*ub = AffineMap::get(/*dimCount=*/0, getNumSymbolVars(), minUbExpr);
1080+
if (minLbPos)
1081+
*minLbPos = minLbPosition;
1082+
if (minUbPos)
1083+
*minUbPos = minUbPosition;
1084+
1085+
return minDiff;
1086+
}
1087+
8941088
IntegerSet FlatLinearConstraints::getAsIntegerSet(MLIRContext *context) const {
8951089
if (getNumConstraints() == 0)
8961090
// Return universal set (always true): 0 == 0.

0 commit comments

Comments
 (0)