@@ -581,6 +581,49 @@ std::pair<AffineMap, AffineMap> FlatLinearConstraints::getLowerAndUpperBound(
581
581
return {lbMap, ubMap};
582
582
}
583
583
584
+ // / Express the pos^th identifier of `cst` as an affine expression in
585
+ // / terms of other identifiers, if they are available in `exprs`, using the
586
+ // / equality at position `idx` in `cs`t. Populates `exprs` with such an
587
+ // / expression if possible, and return true. Returns false otherwise.
588
+ static bool detectAsExpr (const FlatLinearConstraints &cst, unsigned pos,
589
+ unsigned idx, MLIRContext *context,
590
+ SmallVectorImpl<AffineExpr> &exprs) {
591
+ // Initialize with a `0` expression.
592
+ auto expr = getAffineConstantExpr (0 , context);
593
+
594
+ // Traverse `idx`th equality and construct the possible affine expression in
595
+ // terms of known identifiers.
596
+ unsigned j, e;
597
+ for (j = 0 , e = cst.getNumVars (); j < e; ++j) {
598
+ if (j == pos)
599
+ continue ;
600
+ int64_t c = cst.atEq64 (idx, j);
601
+ if (c == 0 )
602
+ continue ;
603
+ // If any of the involved IDs hasn't been found yet, we can't proceed.
604
+ if (!exprs[j])
605
+ break ;
606
+ expr = expr + exprs[j] * c;
607
+ }
608
+ if (j < e)
609
+ // Can't construct expression as it depends on a yet uncomputed
610
+ // identifier.
611
+ return false ;
612
+
613
+ // Add constant term to AffineExpr.
614
+ expr = expr + cst.atEq64 (idx, cst.getNumVars ());
615
+ int64_t vPos = cst.atEq64 (idx, pos);
616
+ assert (vPos != 0 && " expected non-zero here" );
617
+ if (vPos > 0 )
618
+ expr = (-expr).floorDiv (vPos);
619
+ else
620
+ // vPos < 0.
621
+ expr = expr.floorDiv (-vPos);
622
+ // Successfully constructed expression.
623
+ exprs[pos] = expr;
624
+ return true ;
625
+ }
626
+
584
627
// / Compute a representation of `num` identifiers starting at `offset` in `cst`
585
628
// / as affine expressions involving other known identifiers. Each identifier's
586
629
// / expression (in terms of known identifiers) is populated into `memo`.
@@ -636,41 +679,13 @@ static void computeUnknownVars(const FlatLinearConstraints &cst,
636
679
637
680
// Detect a variable as an expression of other variables.
638
681
std::optional<unsigned > idx;
639
- if (!(idx = cst.findConstraintWithNonZeroAt (pos, /* isEq=*/ true ))) {
682
+ if (!(idx = cst.findConstraintWithNonZeroAt (pos, /* isEq=*/ true )))
640
683
continue ;
641
- }
642
684
643
- // Build AffineExpr solving for variable 'pos' in terms of all others.
644
- auto expr = getAffineConstantExpr (0 , context);
645
- unsigned j, e;
646
- for (j = 0 , e = cst.getNumVars (); j < e; ++j) {
647
- if (j == pos)
648
- continue ;
649
- int64_t c = cst.atEq64 (*idx, j);
650
- if (c == 0 )
651
- continue ;
652
- // If any of the involved IDs hasn't been found yet, we can't proceed.
653
- if (!memo[j])
654
- break ;
655
- expr = expr + memo[j] * c;
656
- }
657
- if (j < e)
658
- // Can't construct expression as it depends on a yet uncomputed
659
- // variable.
685
+ if (detectAsExpr (cst, pos, *idx, context, memo)) {
686
+ changed = true ;
660
687
continue ;
661
-
662
- // Add constant term to AffineExpr.
663
- expr = expr + cst.atEq64 (*idx, cst.getNumVars ());
664
- int64_t vPos = cst.atEq64 (*idx, pos);
665
- assert (vPos != 0 && " expected non-zero here" );
666
- if (vPos > 0 )
667
- expr = (-expr).floorDiv (vPos);
668
- else
669
- // vPos < 0.
670
- expr = expr.floorDiv (-vPos);
671
- // Successfully constructed expression.
672
- memo[pos] = expr;
673
- changed = true ;
688
+ }
674
689
}
675
690
// This loop is guaranteed to reach a fixed point - since once an
676
691
// variable's explicit form is computed (in memo[pos]), it's not updated
@@ -891,6 +906,185 @@ FlatLinearConstraints::computeLocalVars(SmallVectorImpl<AffineExpr> &memo,
891
906
llvm::all_of (localExprs, [](AffineExpr expr) { return expr; }));
892
907
}
893
908
909
+ // / Given an equality or inequality (`isEquality` used to disambiguate) of `cst`
910
+ // / at `idx`, traverse and sum up `AffineExpr`s of all known ids other than the
911
+ // / `pos`th. Known `AffineExpr`s are given in `exprs` (unknowns are null). If
912
+ // / the equality/inequality contains any unknown id, return None. Otherwise
913
+ // / return sum as `AffineExpr`.
914
+ static std::optional<AffineExpr> getAsExpr (const FlatLinearConstraints &cst,
915
+ unsigned pos, MLIRContext *context,
916
+ ArrayRef<AffineExpr> exprs,
917
+ unsigned idx, bool isEquality) {
918
+ // Initialize with a `0` expression.
919
+ auto expr = getAffineConstantExpr (0 , context);
920
+
921
+ SmallVector<int64_t , 8 > row =
922
+ isEquality ? cst.getEquality64 (idx) : cst.getInequality64 (idx);
923
+
924
+ // Traverse `idx`th equality and construct the possible affine expression in
925
+ // terms of known identifiers.
926
+ unsigned j, e;
927
+ for (j = 0 , e = cst.getNumVars (); j < e; ++j) {
928
+ if (j == pos)
929
+ continue ;
930
+ int64_t c = row[j];
931
+ if (c == 0 )
932
+ continue ;
933
+ // If any of the involved IDs hasn't been found yet, we can't proceed.
934
+ if (!exprs[j])
935
+ break ;
936
+ expr = expr + exprs[j] * c;
937
+ }
938
+ if (j < e)
939
+ // Can't construct expression as it depends on a yet uncomputed
940
+ // identifier.
941
+ return std::nullopt;
942
+
943
+ // Add constant term to AffineExpr.
944
+ expr = expr + row[cst.getNumVars ()];
945
+ return expr;
946
+ }
947
+
948
+ std::optional<int64_t > FlatLinearConstraints::getConstantBoundOnDimSize (
949
+ MLIRContext *context, unsigned pos, AffineMap *lb, AffineMap *ub,
950
+ unsigned *minLbPos, unsigned *minUbPos) const {
951
+
952
+ assert (pos < getNumDimVars () && " Invalid identifier position" );
953
+
954
+ auto freeOfUnknownLocalVars = [&](ArrayRef<int64_t > cst,
955
+ ArrayRef<AffineExpr> whiteListCols) {
956
+ for (int i = getNumDimAndSymbolVars (), e = cst.size () - 1 ; i < e; ++i) {
957
+ if (whiteListCols[i] && whiteListCols[i].isSymbolicOrConstant ())
958
+ continue ;
959
+ if (cst[i] != 0 )
960
+ return false ;
961
+ }
962
+ return true ;
963
+ };
964
+
965
+ // Detect the necesary local variables first.
966
+ SmallVector<AffineExpr, 8 > memo (getNumVars (), AffineExpr ());
967
+ (void )computeLocalVars (memo, context);
968
+
969
+ // Find an equality for 'pos'^th identifier that equates it to some function
970
+ // of the symbolic identifiers (+ constant).
971
+ int eqPos = findEqualityToConstant (pos, /* symbolic=*/ true );
972
+ // If the equality involves a local var that can not be expressed as a
973
+ // symbolic or constant affine expression, we bail out.
974
+ if (eqPos != -1 && freeOfUnknownLocalVars (getEquality64 (eqPos), memo)) {
975
+ // This identifier can only take a single value.
976
+ if (lb && detectAsExpr (*this , pos, eqPos, context, memo)) {
977
+ AffineExpr equalityExpr =
978
+ simplifyAffineExpr (memo[pos], 0 , getNumSymbolVars ());
979
+ *lb = AffineMap::get (/* dimCount=*/ 0 , getNumSymbolVars (), equalityExpr);
980
+ if (ub)
981
+ *ub = *lb;
982
+ }
983
+ if (minLbPos)
984
+ *minLbPos = eqPos;
985
+ if (minUbPos)
986
+ *minUbPos = eqPos;
987
+ return 1 ;
988
+ }
989
+
990
+ // Positions of constraints that are lower/upper bounds on the variable.
991
+ SmallVector<unsigned , 4 > lbIndices, ubIndices;
992
+
993
+ // Note inequalities that give lower and upper bounds.
994
+ getLowerAndUpperBoundIndices (pos, &lbIndices, &ubIndices,
995
+ /* eqIndices=*/ nullptr , /* offset=*/ 0 ,
996
+ /* num=*/ getNumDimVars ());
997
+
998
+ std::optional<int64_t > minDiff = std::nullopt;
999
+ unsigned minLbPosition = 0 , minUbPosition = 0 ;
1000
+ AffineExpr minLbExpr, minUbExpr;
1001
+
1002
+ // Traverse each lower bound and upper bound pair, to compute the difference
1003
+ // between them.
1004
+ for (unsigned ubPos : ubIndices) {
1005
+ // Construct sum of all ids other than `pos`th in the given upper bound row.
1006
+ std::optional<AffineExpr> maybeUbExpr =
1007
+ getAsExpr (*this , pos, context, memo, ubPos, /* isEquality=*/ false );
1008
+ if (!maybeUbExpr.has_value () || !(*maybeUbExpr).isSymbolicOrConstant ())
1009
+ continue ;
1010
+
1011
+ // Canonical form of an inequality that constrains the upper bound on
1012
+ // an id `x_i` is of the form:
1013
+ // `c_1*x_1 + c_2*x_2 + ... + c_0 >= 0`, where `c_i` <= -1.
1014
+ // Therefore the upper bound on `x_i` will be
1015
+ // `(
1016
+ // sum(c_j*x_j) where j != i
1017
+ // +
1018
+ // c_0
1019
+ // )
1020
+ // /
1021
+ // -(c_i)`. Divison here is a floorDiv.
1022
+ AffineExpr ubExpr = maybeUbExpr->floorDiv (-atIneq64 (ubPos, pos));
1023
+ assert (-atIneq64 (ubPos, pos) > 0 && " invalid upper bound index" );
1024
+
1025
+ // Go over each lower bound.
1026
+ for (unsigned lbPos : lbIndices) {
1027
+ // Construct sum of all ids other than `pos`th in the given lower bound
1028
+ // row.
1029
+ std::optional<AffineExpr> maybeLbExpr =
1030
+ getAsExpr (*this , pos, context, memo, lbPos, /* isEquality=*/ false );
1031
+ if (!maybeLbExpr.has_value () || !(*maybeLbExpr).isSymbolicOrConstant ())
1032
+ continue ;
1033
+
1034
+ // Canonical form of an inequality that is constraining the lower bound
1035
+ // on an id `x_i is of the form:
1036
+ // `c_1*x_1 + c_2*x_2 + ... + c_0 >= 0`, where `c_i` >= 1.
1037
+ // Therefore upperBound on `x_i` will be
1038
+ // `-(
1039
+ // sum(c_j*x_j) where j != i
1040
+ // +
1041
+ // c_0
1042
+ // )
1043
+ // /
1044
+ // c_i`. Divison here is a ceilDiv.
1045
+ int64_t divisor = atIneq64 (lbPos, pos);
1046
+ // We convert the `ceilDiv` for floordiv with the formula:
1047
+ // `expr ceildiv divisor is (expr + divisor - 1) floordiv divisor`,
1048
+ // since uniformly keeping divisons as `floorDiv` helps their
1049
+ // simplification.
1050
+ AffineExpr lbExpr = (-(*maybeLbExpr) + divisor - 1 ).floorDiv (divisor);
1051
+ assert (atIneq64 (lbPos, pos) > 0 && " invalid lower bound index" );
1052
+
1053
+ AffineExpr difference =
1054
+ simplifyAffineExpr (ubExpr - lbExpr + 1 , 0 , getNumSymbolVars ());
1055
+ // If the difference is not constant, ignore the lower bound - upper bound
1056
+ // pair.
1057
+ auto constantDiff = dyn_cast<AffineConstantExpr>(difference);
1058
+ if (!constantDiff)
1059
+ continue ;
1060
+
1061
+ int64_t diffValue = constantDiff.getValue ();
1062
+ // This bound is non-negative by definition.
1063
+ diffValue = std::max<int64_t >(diffValue, 0 );
1064
+ if (!minDiff || diffValue < *minDiff) {
1065
+ minDiff = diffValue;
1066
+ minLbPosition = lbPos;
1067
+ minUbPosition = ubPos;
1068
+ minLbExpr = lbExpr;
1069
+ minUbExpr = ubExpr;
1070
+ }
1071
+ }
1072
+ }
1073
+
1074
+ // Populate outputs where available and needed.
1075
+ if (lb && minDiff) {
1076
+ *lb = AffineMap::get (/* dimCount=*/ 0 , getNumSymbolVars (), minLbExpr);
1077
+ }
1078
+ if (ub)
1079
+ *ub = AffineMap::get (/* dimCount=*/ 0 , getNumSymbolVars (), minUbExpr);
1080
+ if (minLbPos)
1081
+ *minLbPos = minLbPosition;
1082
+ if (minUbPos)
1083
+ *minUbPos = minUbPosition;
1084
+
1085
+ return minDiff;
1086
+ }
1087
+
894
1088
IntegerSet FlatLinearConstraints::getAsIntegerSet (MLIRContext *context) const {
895
1089
if (getNumConstraints () == 0 )
896
1090
// Return universal set (always true): 0 == 0.
0 commit comments