Skip to content

Commit

Permalink
[StableHLO] Port gather canonicalization pattern (iree-org#13771)
Browse files Browse the repository at this point in the history
This pattern folds gather into slice (+ reshape) and is ported from
MHLO.

Issue: iree-org#12678
  • Loading branch information
kuhar authored and NatashaKnk committed Jul 6, 2023
1 parent d432ed8 commit ebad8bd
Show file tree
Hide file tree
Showing 2 changed files with 187 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,82 @@ struct GetDimensionSizeOpCanon final
}
};

/// Rewrites a `stablehlo.gather` whose start indices are a single constant
/// vector into a static `stablehlo.slice`, followed by a
/// `stablehlo.reshape` when the gather collapses slice dimensions.
struct GatherOpCanon final : OpRewritePattern<mlir::stablehlo::GatherOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(mlir::stablehlo::GatherOp gather,
                                PatternRewriter &rewriter) const override {
    // The start indices must be known at compile time.
    DenseIntElementsAttr startIndices;
    if (!matchPattern(gather.getStartIndices(), m_Constant(&startIndices))) {
      return failure();
    }

    // Restrict to a single set of indices: a rank-0 or rank-1 index tensor
    // whose index vector lies along dimension 0.
    mlir::stablehlo::GatherDimensionNumbersAttr dnums =
        gather.getDimensionNumbers();
    if (dnums.getIndexVectorDim() != 0 ||
        startIndices.getType().getRank() > 1) {
      return failure();
    }

    // TODO: Remove this check once the op verifier rejects gathers whose
    // number of index elements does not match the start index map size.
    ArrayRef<int64_t> startIndexMap = dnums.getStartIndexMap();
    if (startIndices.getNumElements() !=
        static_cast<int64_t>(startIndexMap.size())) {
      return failure();
    }

    // A static operand shape is needed both to clamp the indices and to form
    // a static slice.
    auto operandType =
        dyn_cast<RankedTensorType>(gather->getOperand(0).getType());
    if (!operandType || !operandType.hasStaticShape()) return failure();

    // Start with limits equal to the slice sizes, then shift both bounds of
    // each mapped dimension by its (clamped) start index.
    SmallVector<int64_t> limitIndices =
        llvm::to_vector(gather.getSliceSizes().getValues<int64_t>());
    SmallVector<int64_t> startOffsets(limitIndices.size(), 0);
    for (auto [dim, rawIndex] :
         llvm::zip_equal(startIndexMap, startIndices.getValues<APInt>())) {
      // Clamp the indices within bounds to faithfully mirror gather semantics.
      int64_t clamped =
          std::clamp(rawIndex.getSExtValue(), static_cast<int64_t>(0),
                     operandType.getDimSize(dim) - limitIndices[dim]);
      startOffsets[dim] += clamped;
      limitIndices[dim] += clamped;
    }

    SmallVector<int64_t> strides(limitIndices.size(), 1);
    SmallVector<int64_t> sliceShape;
    sliceShape.reserve(limitIndices.size());
    for (auto [start, limit] : llvm::zip_equal(startOffsets, limitIndices)) {
      sliceShape.push_back(limit - start);
    }

    Type elementType = gather.getType().getElementType();
    auto sliceType = RankedTensorType::get(sliceShape, elementType);
    Value result = rewriter.create<mlir::stablehlo::SliceOp>(
        gather.getLoc(), sliceType, gather.getOperand(),
        rewriter.getI64TensorAttr(startOffsets),
        rewriter.getI64TensorAttr(limitIndices),
        rewriter.getI64TensorAttr(strides));

    // Dimensions collapsed by the gather are dropped with a reshape.
    ArrayRef<int64_t> collapsedSliceDims = dnums.getCollapsedSliceDims();
    if (!collapsedSliceDims.empty()) {
      SmallVector<int64_t> reshapeShape;
      for (auto [idx, dim] : llvm::enumerate(sliceShape)) {
        if (!llvm::is_contained(collapsedSliceDims, idx)) {
          reshapeShape.push_back(dim);
        }
      }
      auto reshapeType = RankedTensorType::get(reshapeShape, elementType);
      result = rewriter.create<mlir::stablehlo::ReshapeOp>(
          gather.getLoc(), reshapeType, result);
    }

    result.setType(gather.getType());
    rewriter.replaceOp(gather, result);
    return success();
  }
};

struct ReshapeOpCanon final : OpRewritePattern<mlir::stablehlo::ReshapeOp> {
using OpRewritePattern::OpRewritePattern;

Expand Down Expand Up @@ -710,7 +786,7 @@ void populateCanonicalizationPatterns(MLIRContext *context,
GetDimensionSizeOpCanon, GetTupleElementOpCanon,
// Shape manipulation(-ish) ops.
BroadcastInDimOpCanon, ConcatenateOpCanon, ConvertOpCanon,
DynamicReshapeOpCanon, ReshapeOpCanon, TransposeOpCanon>(context,
benefit);
DynamicReshapeOpCanon, GatherOpCanon, ReshapeOpCanon, TransposeOpCanon>(
context, benefit);
}
} // namespace mlir::iree_compiler::stablehlo
Original file line number Diff line number Diff line change
Expand Up @@ -414,3 +414,112 @@ func.func @transpose(%arg0: tensor<2xf32>, %arg1: tensor<1x2xf32>, %arg2: tensor
// CHECK-NEXT: return [[ARG0]], [[ARG1]], [[X]], [[ARG2]]
return %a, %b, %c, %d : tensor<2xf32>, tensor<1x2xf32>, tensor<2x1xf32>, tensor<f32>
}

// -----

// Gather with constant rank-1 start indices [1, 2] mapped to operand dims
// 0 and 2 should fold into a single static slice: start = [1, 0, 2],
// limit = start + slice_sizes.
// CHECK-LABEL: func.func @gather_to_slice
func.func @gather_to_slice(%arg0: tensor<5x6x7xf32>) -> tensor<3x6x5xf32> {
%0 = arith.constant dense<[1, 2]> : tensor<2xi32>
%1 = "stablehlo.gather"(%arg0, %0) {
dimension_numbers = #stablehlo.gather<
index_vector_dim = 0,
offset_dims = [0, 1, 2],
start_index_map = [0, 2],
>,
indices_are_sorted = false,
slice_sizes = dense<[3, 6, 5]> : tensor<3xi64>} : (tensor<5x6x7xf32>, tensor<2xi32>) -> tensor<3x6x5xf32>
return %1 : tensor<3x6x5xf32>
// CHECK: %[[RET:.*]] = "stablehlo.slice"(%arg0)
// CHECK-SAME: {limit_indices = dense<[4, 6, 7]> : tensor<3xi64>,
// CHECK-SAME: start_indices = dense<[1, 0, 2]> : tensor<3xi64>,
// CHECK-SAME: strides = dense<1> : tensor<3xi64>} : (tensor<5x6x7xf32>) -> tensor<3x6x5xf32>
// CHECK-NEXT: return %[[RET]] : tensor<3x6x5xf32>
}

// -----

// A rank-0 (scalar) constant index is also accepted: only dim 2 is mapped,
// so the slice starts at [0, 0, 1].
// CHECK-LABEL: func.func @gather_scalar_index_to_slice
func.func @gather_scalar_index_to_slice(%arg0: tensor<5x6x7xf32>) -> tensor<5x6x4xf32> {
%0 = arith.constant dense<1> : tensor<i32>
%1 = "stablehlo.gather"(%arg0, %0) {
dimension_numbers = #stablehlo.gather<
index_vector_dim = 0,
offset_dims = [0, 1, 2],
start_index_map = [2],
>,
indices_are_sorted = false,
slice_sizes = dense<[5, 6, 4]> : tensor<3xi64>} : (tensor<5x6x7xf32>, tensor<i32>) -> tensor<5x6x4xf32>
return %1 : tensor<5x6x4xf32>
// CHECK: %[[RET:.*]] = "stablehlo.slice"(%arg0)
// CHECK-SAME: {limit_indices = dense<[5, 6, 5]> : tensor<3xi64>,
// CHECK-SAME: start_indices = dense<[0, 0, 1]> : tensor<3xi64>,
// CHECK-SAME: strides = dense<1> : tensor<3xi64>} : (tensor<5x6x7xf32>) -> tensor<5x6x4xf32>
// CHECK-NEXT: return %[[RET]] : tensor<5x6x4xf32>
}

// -----

// With collapsed_slice_dims = [2] the canonicalization emits a slice that
// keeps the size-1 dim, then a reshape that drops it (3x6x1 -> 3x6).
// CHECK-LABEL: func.func @gather_to_slice_reshape
func.func @gather_to_slice_reshape(%arg0: tensor<5x6x7xf32>) -> tensor<3x6xf32> {
%0 = arith.constant dense<[1, 2]> : tensor<2xi32>
%1 = "stablehlo.gather"(%arg0, %0) {
dimension_numbers = #stablehlo.gather<
collapsed_slice_dims = [2],
index_vector_dim = 0,
offset_dims = [0, 1],
start_index_map = [0, 2],
>,
indices_are_sorted = false,
slice_sizes = dense<[3, 6, 1]> : tensor<3xi64>} : (tensor<5x6x7xf32>, tensor<2xi32>) -> tensor<3x6xf32>
return %1 : tensor<3x6xf32>
// CHECK: %[[V0:.*]] = "stablehlo.slice"(%arg0)
// CHECK-SAME: {limit_indices = dense<[4, 6, 3]> : tensor<3xi64>,
// CHECK-SAME: start_indices = dense<[1, 0, 2]> : tensor<3xi64>,
// CHECK-SAME: strides = dense<1> : tensor<3xi64>} : (tensor<5x6x7xf32>) -> tensor<3x6x1xf32>
// CHECK-NEXT: %[[V1:.*]] = stablehlo.reshape %[[V0]] : (tensor<3x6x1xf32>) -> tensor<3x6xf32>
// CHECK-NEXT: return %[[V1]] : tensor<3x6xf32>
}

// -----

// Out-of-range start index 4 on a dim of size 4 with slice size 1 must be
// clamped to the largest valid start (4 - 1 = 3), mirroring gather
// semantics.
// CHECK-LABEL: func.func @gather_to_slice_indices_clamp_upperbound
func.func @gather_to_slice_indices_clamp_upperbound(%arg0 : tensor<4x2xui32>) -> tensor<2xui32> {
%0 = arith.constant dense<4> : tensor<1xi32>
%1 = "stablehlo.gather"(%arg0, %0) {
dimension_numbers = #stablehlo.gather<
offset_dims = [0],
index_vector_dim = 0,
collapsed_slice_dims = [0],
start_index_map = [0]
>, indices_are_sorted = true,
slice_sizes = dense<[1, 2]> : tensor<2xi64>} : (tensor<4x2xui32>, tensor<1xi32>) -> tensor<2xui32>
return %1 : tensor<2xui32>
// CHECK: %[[V0:.*]] = "stablehlo.slice"(%arg0)
// CHECK-SAME: {limit_indices = dense<[4, 2]> : tensor<2xi64>,
// CHECK-SAME: start_indices = dense<[3, 0]> : tensor<2xi64>,
// CHECK-SAME: strides = dense<1> : tensor<2xi64>} : (tensor<4x2xui32>) -> tensor<1x2xui32>
// CHECK-NEXT: %[[V1:.*]] = stablehlo.reshape %[[V0]] : (tensor<1x2xui32>) -> tensor<2xui32>
// CHECK-NEXT: return %[[V1]] : tensor<2xui32>
}

// -----

// Negative start index -1 must be clamped up to 0 before forming the slice,
// mirroring gather semantics.
// CHECK-LABEL: func.func @gather_to_slice_indices_clamp_lowerbound
func.func @gather_to_slice_indices_clamp_lowerbound(%arg0 : tensor<4x2xui32>) -> tensor<2xui32> {
%0 = arith.constant dense<-1> : tensor<1xi32>
%1 = "stablehlo.gather"(%arg0, %0) {
dimension_numbers = #stablehlo.gather<
offset_dims = [0],
index_vector_dim = 0,
collapsed_slice_dims = [0],
start_index_map = [0]
>, indices_are_sorted = true,
slice_sizes = dense<[1, 2]> : tensor<2xi64>} : (tensor<4x2xui32>, tensor<1xi32>) -> tensor<2xui32>
return %1 : tensor<2xui32>
// CHECK: %[[V0:.*]] = "stablehlo.slice"(%arg0)
// CHECK-SAME: {limit_indices = dense<[1, 2]> : tensor<2xi64>,
// CHECK-SAME: start_indices = dense<0> : tensor<2xi64>,
// CHECK-SAME: strides = dense<1> : tensor<2xi64>} : (tensor<4x2xui32>) -> tensor<1x2xui32>
// CHECK-NEXT: %[[V1:.*]] = stablehlo.reshape %[[V0]] : (tensor<1x2xui32>) -> tensor<2xui32>
// CHECK-NEXT: return %[[V1]] : tensor<2xui32>
}

0 comments on commit ebad8bd

Please sign in to comment.