[MLIR] Fix incorrect memref::DimOp canonicalization, add tensor::DimOp canonicalization (llvm#84225)

The current canonicalization that folds `memref.dim` on the result of a
`memref.reshape` into a `memref.load` is incorrect: it never checks
whether the `index` operand of `memref.dim` dominates the source
`memref.reshape` op. The pattern always inserts the `memref.load` right
after the `memref.reshape` to ensure the shape `memref` is not mutated
before the load, so if the `index` value is only defined after the
reshape, the newly created load uses a value that does not dominate it.
As a result, the following input:

```
$> mlir-opt --canonicalize input.mlir

func.func @reshape_dim(%arg0: memref<*xf32>, %arg1: memref<?xindex>, %arg2: index) -> index {
    %c4 = arith.constant 4 : index
    %reshape = memref.reshape %arg0(%arg1) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
    %0 = arith.muli %arg2, %c4 : index
    %dim = memref.dim %reshape, %0 : memref<*xf32>
    return %dim : index
  }
```

results in:

```
dominator.mlir:22:12: error: operand #1 does not dominate this use
    %dim = memref.dim %reshape, %0 : memref<*xf32>
           ^
dominator.mlir:22:12: note: see current operation: %1 = "memref.load"(%arg1, %2) <{nontemporal = false}> : (memref<?xindex>, index) -> index
dominator.mlir:21:10: note: operand defined here (op in the same block)
    %0 = arith.muli %arg2, %c4 : index
```
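
In other words, the old pattern rewrote the IR into something like the sketch below (hand-written to illustrate the problem, not verbatim `mlir-opt` output): the `memref.load` that replaces the `memref.dim` is inserted right after the `memref.reshape`, but its index operand is only defined afterwards, which is exactly the dominance violation the verifier reports.

```
func.func @reshape_dim(%arg0: memref<*xf32>, %arg1: memref<?xindex>, %arg2: index) -> index {
  %c4 = arith.constant 4 : index
  %reshape = memref.reshape %arg0(%arg1) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
  // The load replacing memref.dim is placed right after the reshape ...
  %dim = memref.load %arg1[%0] : memref<?xindex>
  // ... but its index operand %0 is only defined here, after the use.
  %0 = arith.muli %arg2, %c4 : index
  return %dim : index
}
```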

Properly fixing this issue requires a dominance analysis, which is too
expensive to run inside a canonicalization pattern. Instead, this patch
makes the legality check more strict/conservative: the fold is performed
only when the `index` operand trivially dominates the reshape, i.e. it is
defined earlier in the same block as the reshape (or is a block argument
of that block), or it is defined in an ancestor region of the reshape.
The same pattern, in its more general form, is also added for `tensor.dim`.
Since tensors are immutable, there is no constraint on where the
`tensor.extract` is inserted after canonicalization.
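
For comparison, the fold is still applied when the index trivially dominates the reshape, e.g. when it is a block argument. A minimal sketch (the function name below is illustrative; the behavior mirrors the new `@dim_of_memref_reshape_block_arg_index` test added in this patch):

```
// Input: %arg2 is a block argument, so it dominates the reshape.
func.func @reshape_dim_ok(%arg0: memref<*xf32>, %arg1: memref<?xindex>, %arg2: index) -> index {
  %reshape = memref.reshape %arg0(%arg1) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
  %dim = memref.dim %reshape, %arg2 : memref<*xf32>
  return %dim : index
}

// Expected result of `mlir-opt --canonicalize` (per the test's CHECK lines):
func.func @reshape_dim_ok(%arg0: memref<*xf32>, %arg1: memref<?xindex>, %arg2: index) -> index {
  %dim = memref.load %arg1[%arg2] : memref<?xindex>
  return %dim : index
}
```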
sahas3 committed Mar 12, 2024
1 parent 2a30684 commit 26722f5
Showing 4 changed files with 191 additions and 2 deletions.
32 changes: 31 additions & 1 deletion mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1080,7 +1080,37 @@ struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
     auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
 
     if (!reshape)
-      return failure();
+      return rewriter.notifyMatchFailure(
+          dim, "Dim op is not defined by a reshape op.");
+
+    // dim of a memref reshape can be folded if dim.getIndex() dominates the
+    // reshape. Instead of using `DominanceInfo` (which is usually costly) we
+    // cheaply check that either of the following conditions hold:
+    //      1. dim.getIndex() is defined in the same block as reshape but before
+    //      reshape.
+    //      2. dim.getIndex() is defined in a parent block of
+    //      reshape.
+
+    // Check condition 1
+    if (dim.getIndex().getParentBlock() == reshape->getBlock()) {
+      if (auto *definingOp = dim.getIndex().getDefiningOp()) {
+        if (reshape->isBeforeInBlock(definingOp)) {
+          return rewriter.notifyMatchFailure(
+              dim,
+              "dim.getIndex is not defined before reshape in the same block.");
+        }
+      } // else dim.getIndex is a block argument to reshape->getBlock and
+        // dominates reshape
+    } // Check condition 2
+    else if (dim->getBlock() != reshape->getBlock() &&
+             !dim.getIndex().getParentRegion()->isProperAncestor(
+                 reshape->getParentRegion())) {
+      // If dim and reshape are in the same block but dim.getIndex() isn't, we
+      // already know dim.getIndex() dominates reshape without calling
+      // `isProperAncestor`
+      return rewriter.notifyMatchFailure(
+          dim, "dim.getIndex does not dominate reshape.");
+    }
 
     // Place the load directly after the reshape to ensure that the shape memref
     // was not mutated.
28 changes: 27 additions & 1 deletion mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -824,11 +824,37 @@ struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
     return success();
   }
 };
+
+/// Fold dim of a tensor reshape operation to a extract into the reshape's shape
+/// operand.
+struct DimOfReshapeOp : public OpRewritePattern<DimOp> {
+  using OpRewritePattern<DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DimOp dim,
+                                PatternRewriter &rewriter) const override {
+    auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
+
+    if (!reshape)
+      return failure();
+
+    // Since tensors are immutable we don't need to worry about where to place
+    // the extract call
+    rewriter.setInsertionPointAfter(dim);
+    Location loc = dim.getLoc();
+    Value extract =
+        rewriter.create<ExtractOp>(loc, reshape.getShape(), dim.getIndex());
+    if (extract.getType() != dim.getType())
+      extract =
+          rewriter.create<arith::IndexCastOp>(loc, dim.getType(), extract);
+    rewriter.replaceOp(dim, extract);
+    return success();
+  }
+};
 } // namespace
 
 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
-  results.add<DimOfCastOp, DimOfDestStyleOp>(context);
+  results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
 }
 
 //===----------------------------------------------------------------------===//
53 changes: 53 additions & 0 deletions mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -313,6 +313,59 @@ func.func @dim_of_memref_reshape_i32(%arg0: memref<*xf32>, %arg1: memref<?xi32>)
 
 // -----
 
+// Test case: memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
+// CHECK-LABEL: func @dim_of_memref_reshape_block_arg_index(
+// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>,
+// CHECK-SAME: %[[SHP:[0-9a-z]+]]: memref<?xindex>,
+// CHECK-SAME: %[[IDX:[0-9a-z]+]]: index
+// CHECK-NEXT: %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
+// CHECK-NOT: memref.dim
+// CHECK: return %[[DIM]] : index
+func.func @dim_of_memref_reshape_block_arg_index(%arg0: memref<*xf32>, %arg1: memref<?xindex>, %arg2: index) -> index {
+  %reshape = memref.reshape %arg0(%arg1) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
+  %dim = memref.dim %reshape, %arg2 : memref<*xf32>
+  return %dim : index
+}
+
+// -----
+
+// Test case: memref.dim(memref.reshape %v %shp, %idx) is not folded into memref.load %shp[%idx]
+// CHECK-LABEL: func @dim_of_memref_reshape_for(
+// CHECK: memref.reshape
+// CHECK: memref.dim
+// CHECK-NOT: memref.load
+func.func @dim_of_memref_reshape_for( %arg0: memref<*xf32>, %arg1: memref<?xindex>) -> index {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+
+  %0 = memref.reshape %arg0(%arg1) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
+
+  %1 = scf.for %arg2 = %c0 to %c4 step %c1 iter_args(%arg3 = %c1) -> (index) {
+    %2 = memref.dim %0, %arg2 : memref<*xf32>
+    %3 = arith.muli %arg3, %2 : index
+    scf.yield %3 : index
+  }
+  return %1 : index
+}
+
+// -----
+
+// Test case: memref.dim(memref.reshape %v %shp, %idx) is not folded into memref.load %shp[%idx]
+// CHECK-LABEL: func @dim_of_memref_reshape_undominated(
+// CHECK: memref.reshape
+// CHECK: memref.dim
+// CHECK-NOT: memref.load
+func.func @dim_of_memref_reshape_undominated(%arg0: memref<*xf32>, %arg1: memref<?xindex>, %arg2: index) -> index {
+  %c4 = arith.constant 4 : index
+  %reshape = memref.reshape %arg0(%arg1) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
+  %0 = arith.muli %arg2, %c4 : index
+  %dim = memref.dim %reshape, %0 : memref<*xf32>
+  return %dim : index
+}
+
+// -----
+
 // CHECK-LABEL: func @alloc_const_fold
 func.func @alloc_const_fold() -> memref<?xf32> {
   // CHECK-NEXT: memref.alloc() : memref<4xf32>
80 changes: 80 additions & 0 deletions mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2287,3 +2287,83 @@ func.func @infer_and_fold_pack_unpack_same_tiles(%t: tensor<10x20x4x4xf32>) -> t
 // CHECK-LABEL: func.func @infer_and_fold_pack_unpack_same_tiles
 // CHECK-SAME: %[[SRC:[0-9a-zA-Z]+]]
 // CHECK: return %[[SRC]]
+
+// -----
+
+// Test case: Folding of tensor.dim(tensor.reshape %v %shp, %idx) -> tensor.extract %shp[%idx]
+// CHECK-LABEL: func @dim_of_reshape(
+// CHECK-SAME: %[[MEM:[0-9a-z]+]]: tensor<*xf32>,
+// CHECK-SAME: %[[SHP:[0-9a-z]+]]: tensor<?xindex>
+// CHECK-NEXT: %[[IDX:.*]] = arith.constant 3
+// CHECK-NEXT: %[[DIM:.*]] = tensor.extract %[[SHP]][%[[IDX]]]
+// CHECK-NOT: tensor.store
+// CHECK-NOT: tensor.dim
+// CHECK-NOT: tensor.reshape
+// CHECK: return %[[DIM]] : index
+func.func @dim_of_reshape(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>)
+    -> index {
+  %c3 = arith.constant 3 : index
+  %0 = tensor.reshape %arg0(%arg1)
+      : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+  // Update the shape to test that the load ends up in the right place.
+  tensor.insert %c3 into %arg1[%c3] : tensor<?xindex>
+  %1 = tensor.dim %0, %c3 : tensor<*xf32>
+  return %1 : index
+}
+
+// -----
+
+// Test case: Folding of tensor.dim(tensor.reshape %v %shp, %idx) -> tensor.extract %shp[%idx]
+// CHECK-LABEL: func @dim_of_reshape_i32(
+// CHECK: tensor.extract
+// CHECK-NEXT: %[[CAST:.*]] = arith.index_cast
+// CHECK-NOT: tensor.dim
+// CHECK-NOT: tensor.reshape
+// CHECK: return %[[CAST]] : index
+func.func @dim_of_reshape_i32(%arg0: tensor<*xf32>, %arg1: tensor<?xi32>)
+    -> index {
+  %c3 = arith.constant 3 : index
+  %0 = tensor.reshape %arg0(%arg1)
+      : (tensor<*xf32>, tensor<?xi32>) -> tensor<*xf32>
+  %1 = tensor.dim %0, %c3 : tensor<*xf32>
+  return %1 : index
+}
+
+// -----
+
+// Test case: tensor.dim(tensor.reshape %v %shp, %idx) is folded into tensor.extract %shp[%idx]
+// CHECK-LABEL: func @dim_of_reshape_for(
+// CHECK: scf.for
+// CHECK-NEXT: tensor.extract
+// CHECK-NOT: tensor.dim
+// CHECK-NOT: tensor.reshape
+func.func @dim_of_reshape_for( %arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> index {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+
+  %0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+
+  %1 = scf.for %arg2 = %c0 to %c4 step %c1 iter_args(%arg3 = %c1) -> (index) {
+    %2 = tensor.dim %0, %arg2 : tensor<*xf32>
+    %3 = arith.muli %arg3, %2 : index
+    scf.yield %3 : index
+  }
+  return %1 : index
+}
+
+// -----
+
+// Test case: tensor.dim(tensor.reshape %v %shp, %idx) is folded into tensor.extract %shp[%idx]
+// CHECK-LABEL: func @dim_of_reshape_undominated(
+// CHECK: arith.muli
+// CHECK-NEXT: tensor.extract
+// CHECK-NOT: tensor.dim
+// CHECK-NOT: tensor.reshape
+func.func @dim_of_reshape_undominated(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>, %arg2: index) -> index {
+  %c4 = arith.constant 4 : index
+  %reshape = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+  %0 = arith.muli %arg2, %c4 : index
+  %dim = tensor.dim %reshape, %0 : tensor<*xf32>
+  return %dim : index
+}
