Skip to content

Commit

Permalink
[mlir][vector] Fold extract(shape_cast) for same element count
Browse files Browse the repository at this point in the history
Reviewed By: hanchung

Differential Revision: https://reviews.llvm.org/D157930
  • Loading branch information
antiagainst committed Aug 15, 2023
1 parent 9f37c21 commit 7897a94
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 0 deletions.
22 changes: 22 additions & 0 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1808,12 +1808,34 @@ class ExtractOpNonSplatConstantFolder final
}
};

// Folds extract(shape_cast(..)) into shape_cast when the total element count
// does not change.
LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
PatternRewriter &rewriter) {
auto castOp = extractOp.getVector().getDefiningOp<ShapeCastOp>();
if (!castOp)
return failure();

VectorType sourceType = castOp.getSourceVectorType();
auto targetType = dyn_cast<VectorType>(extractOp.getResult().getType());
if (!targetType)
return failure();

if (sourceType.getNumElements() != targetType.getNumElements())
return failure();

rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, targetType,
castOp.getSource());
return success();
}

} // namespace

void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<ExtractOpSplatConstantFolder, ExtractOpNonSplatConstantFolder,
ExtractOpFromBroadcast>(context);
results.add(foldExtractFromShapeCastToShapeCast);
}

static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
Expand Down
12 changes: 12 additions & 0 deletions mlir/test/Dialect/Vector/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -669,6 +669,18 @@ func.func @dont_fold_0d_extract_shapecast(%arg0 : vector<f32>) -> f32 {

// -----

// CHECK-LABEL: fold_extract_shapecast_to_shapecast
// CHECK-SAME: (%[[ARG:.+]]: vector<3x4xf32>)
// CHECK: %[[R:.+]] = vector.shape_cast %[[ARG]] : vector<3x4xf32> to vector<12xf32>
// CHECK: return %[[R]]
func.func @fold_extract_shapecast_to_shapecast(%arg0 : vector<3x4xf32>) -> vector<12xf32> {
%0 = vector.shape_cast %arg0 : vector<3x4xf32> to vector<1x12xf32>
%r = vector.extract %0[0] : vector<1x12xf32>
return %r : vector<12xf32>
}

// -----

// CHECK-LABEL: dont_fold_expand_collapse
// CHECK: %[[A:.*]] = vector.shape_cast %{{.*}} : vector<1x1x64xf32> to vector<1x1x8x8xf32>
// CHECK: %[[B:.*]] = vector.shape_cast %{{.*}} : vector<1x1x8x8xf32> to vector<8x8xf32>
Expand Down

0 comments on commit 7897a94

Please sign in to comment.