From 3e3f092bc2bc5e6bc316b9a1d667005f5c2710b7 Mon Sep 17 00:00:00 2001 From: Corentin Ferry Date: Thu, 6 Mar 2025 07:17:24 -0700 Subject: [PATCH] Ensure 0 <= x mod N < N semantics --- mlir/lib/Analysis/AffineExprBounds.cpp | 26 ++++++++++++++----- .../Analysis/test-affine-expr-bounds.mlir | 10 +++++++ 2 files changed, 29 insertions(+), 7 deletions(-) diff --git a/mlir/lib/Analysis/AffineExprBounds.cpp b/mlir/lib/Analysis/AffineExprBounds.cpp index 92a63d0004687..b71cfe4721323 100644 --- a/mlir/lib/Analysis/AffineExprBounds.cpp +++ b/mlir/lib/Analysis/AffineExprBounds.cpp @@ -12,6 +12,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Analysis/AffineExprBounds.h" +#include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/BuiltinAttributes.h" @@ -158,13 +159,24 @@ AffineExprBoundsVisitor::visitFloorDivExpr(AffineBinaryOpExpr expr) { return failure(); } LogicalResult AffineExprBoundsVisitor::visitModExpr(AffineBinaryOpExpr expr) { - inferBinOpRange( - expr, [boundsSigned = boundsSigned](ArrayRef ranges) { - if (boundsSigned) { - return intrange::inferRemS(ranges); - } - return intrange::inferRemU(ranges); - }); + // Only support integers >= 1 as RHS. + auto rhsConst = dyn_cast(expr.getRHS()); + if (!rhsConst || rhsConst.getValue() < 1) + return failure(); + + inferBinOpRange(expr, [boundsSigned = + boundsSigned](ArrayRef ranges) { + // Mod must return a value between 0 and N-1. + // Computing (N + (expr mod N)) mod N is guaranteed to yield a result in + // this range. + if (boundsSigned) { + auto rhs = ranges[1]; + auto lhs = ranges[0]; + return intrange::inferRemS( + {intrange::inferAdd({intrange::inferRemS({lhs, rhs}), rhs}), rhs}); + } + return intrange::inferRemU(ranges); + }); return success(); } LogicalResult AffineExprBoundsVisitor::visitDimExpr(AffineDimExpr expr) { diff --git a/mlir/test/Analysis/test-affine-expr-bounds.mlir b/mlir/test/Analysis/test-affine-expr-bounds.mlir index e4af66f1b8d13..03115760a29d0 100644 --- a/mlir/test/Analysis/test-affine-expr-bounds.mlir +++ b/mlir/test/Analysis/test-affine-expr-bounds.mlir @@ -52,6 +52,16 @@ func.func @test_compute_affine_expr_bounds() { // CHECK-SAME: expr_ub = 3 "test.mod_not_wrapping_around"() {affine_map = affine_map<(d0) -> (((d0 + 12) mod 11) mod 5)>, lbs = [0], ubs = [2]} : () -> () + // CHECK: "test.mod_neg"() + // CHECK-SAME: expr_lb = 1 + // CHECK-SAME: expr_ub = 3 + "test.mod_neg"() {affine_map = affine_map<(d0) -> (d0 mod 5)>, lbs = [-4], ubs = [-2]} : () -> () + + // CHECK: "test.mod_wrapping_by_zero"() + // CHECK-SAME: expr_lb = 0 + // CHECK-SAME: expr_ub = 4 + "test.mod_wrapping_by_zero"() {affine_map = affine_map<(d0) -> (d0 mod 5)>, lbs = [-2], ubs = [1]} : () -> () + // FloorDiv // CHECK: "test.floordiv_basic"()