Skip to content

Commit

Permalink
[Torch] eliminate "getWithLeastStaticInformation" in DecomposeAtenTriuOp (llvm#3330)
Browse files Browse the repository at this point in the history

This change eliminates the use of 'getWithLeastStaticInformation' in
DecomposeAtenTriuOp by deriving result tensor types from the input's
dimension sizes instead.
Reviewed-by: @qingyunqu @zjgarvey
See issue llvm#3312
  • Loading branch information
Xinyu302 authored and Branko Trifkovic committed May 24, 2024
1 parent 313dc33 commit eb8ea86
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 21 deletions.
4 changes: 4 additions & 0 deletions include/torch-mlir/Dialect/Torch/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,10 @@ int64_t getNumberOfElements(RankedTensorType inputType);
SmallVector<int64_t> makeShapeLLVMCompatible(ArrayRef<int64_t> shape);
SmallVector<int64_t> makeShapeTorchCompatible(ArrayRef<int64_t> shape);

ValueTensorType getTensorTypeFromShapeValues(ArrayRef<Value> shapes,
Type dtype);
Value getTensorDimSize(PatternRewriter &rewriter, Value tensor, int64_t dim);

// Helper function to squeeze the input tensor at given dim.
// Return the squeezed tensor or failure.
FailureOr<Value> squeezeTensor(PatternRewriter &rewriter, Operation *op,
Expand Down
54 changes: 33 additions & 21 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -674,7 +674,6 @@ class DecomposeAtenTriuOp : public OpRewritePattern<AtenTriuOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenTriuOp op,
PatternRewriter &rewriter) const override {
MLIRContext *context = op.getContext();
Location loc = op.getLoc();
Value input = op.getSelf();
auto inputType = cast<BaseTensorType>(input.getType());
Expand All @@ -685,37 +684,50 @@ class DecomposeAtenTriuOp : public OpRewritePattern<AtenTriuOp> {
return rewriter.notifyMatchFailure(op, "the rank of tensor should >= 2");
}

auto baseType = ValueTensorType::getWithLeastStaticInformation(context);
Value cstZero =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
Value cstOne =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
Value none = rewriter.create<ConstantNoneOp>(loc);

Value rowDim = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(-2));
Value colDim = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(-1));
Value rowSize = rewriter.create<AtenSizeIntOp>(loc, input, rowDim);
Value colSize = rewriter.create<AtenSizeIntOp>(loc, input, colDim);

Value rowArange = rewriter.create<AtenArangeOp>(
loc, baseType, rowSize, /*dtype=*/none, /*layout=*/none,
/*device=*/none, /*pin_memory=*/none);
Value colArange = rewriter.create<AtenArangeOp>(
loc, baseType, colSize, /*dtype=*/none, /*layout=*/none,
/*device=*/none, /*pin_memory=*/none);
Value rowSize = getTensorDimSize(rewriter, input, -2);
Value colSize = getTensorDimSize(rewriter, input, -1);

auto si64Type = rewriter.getIntegerType(/*width=*/64, /*isSigned*/ true);
auto int64DtypeInt = getDtypeIntValueForType(rewriter, loc, si64Type);
auto rowArrangeType = getTensorTypeFromShapeValues({rowSize}, si64Type);
auto colArrangeType = getTensorTypeFromShapeValues({colSize}, si64Type);

Value rowArange =
rewriter.create<AtenArangeOp>(loc, rowArrangeType, rowSize,
/*dtype=*/int64DtypeInt, /*layout=*/none,
/*device=*/none, /*pin_memory=*/none);
Value colArange =
rewriter.create<AtenArangeOp>(loc, colArrangeType, colSize,
/*dtype=*/int64DtypeInt, /*layout=*/none,
/*device=*/none, /*pin_memory=*/none);

auto unsqueezeRowArangeInfo =
unsqueezeTensor(rewriter, op, rowArange, cstOne);
auto unsqueezeColArangeInfo =
unsqueezeTensor(rewriter, op, colArange, cstZero);

if (failed(unsqueezeRowArangeInfo) || failed(unsqueezeColArangeInfo)) {
return rewriter.notifyMatchFailure(op,
"cannot generate unsqueeze tensor");
}

Value unsqueezeRowArange =
rewriter.create<AtenUnsqueezeOp>(loc, baseType, rowArange, cstOne);
Value unsqueezeColArange =
rewriter.create<AtenUnsqueezeOp>(loc, baseType, colArange, cstZero);
Value unsqueezeRowArange = unsqueezeRowArangeInfo.value();
Value unsqueezeColArange = unsqueezeColArangeInfo.value();

Value unsqueezeRowArangePlusDiagonal = rewriter.create<AtenAddScalarOp>(
loc, baseType, unsqueezeRowArange, op.getDiagonal(), cstOne);
loc, unsqueezeRowArange.getType(), unsqueezeRowArange, op.getDiagonal(),
cstOne);

auto boolType = rewriter.getI1Type();
auto condType = getTensorTypeFromShapeValues({rowSize, colSize}, boolType);
Value condTensor = rewriter.create<AtenGeTensorOp>(
loc, baseType, unsqueezeColArange, unsqueezeRowArangePlusDiagonal);
loc, condType, unsqueezeColArange, unsqueezeRowArangePlusDiagonal);

rewriter.replaceOpWithNewOp<AtenWhereScalarOtherOp>(
op, op.getResult().getType(), condTensor, input, cstZero);
Expand Down
26 changes: 26 additions & 0 deletions lib/Dialect/Torch/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,32 @@ SmallVector<int64_t> Torch::makeShapeTorchCompatible(ArrayRef<int64_t> shape) {
return updatedShape;
}

/// Build a !torch.vtensor type whose rank equals `shapes.size()`, using
/// `dtype` as the element type. Each entry of `shapes` that is a Torch
/// constant int contributes a static extent; every other entry becomes a
/// dynamic dimension (kUnknownSize).
ValueTensorType Torch::getTensorTypeFromShapeValues(ArrayRef<Value> shapes,
                                                    Type dtype) {
  assert(!shapes.empty() && "shape vector cannot be empty");
  SmallVector<int64_t> staticSizes;
  staticSizes.reserve(shapes.size());
  for (Value extent : shapes) {
    int64_t constExtent;
    bool isConst = matchPattern(extent, m_TorchConstantInt(&constExtent));
    staticSizes.push_back(isConst ? constExtent : kUnknownSize);
  }
  // Any of the shape values carries the MLIRContext; use the first one.
  return Torch::ValueTensorType::get(shapes[0].getContext(), staticSizes,
                                     dtype);
}

// Emit IR that yields the extent of `tensor` along dimension `dim` (as a
// !torch.int Value). `dim` may be negative, following Torch conventions;
// AtenSizeIntOp handles the normalization.
Value Torch::getTensorDimSize(PatternRewriter &rewriter, Value tensor,
                              int64_t dim) {
  Location loc = tensor.getLoc();
  Value dimIndex =
      rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(dim));
  // 'createOrFold' rather than 'create': when the queried extent is
  // statically known, the AtenSizeIntOp folds immediately into a
  // ConstantIntOp instead of surviving as a runtime size query.
  return rewriter.createOrFold<AtenSizeIntOp>(loc, tensor, dimIndex);
}

// Helper function to squeeze the input tensor at given dim.
// Return the squeezed tensor or failure.
FailureOr<Value> Torch::squeezeTensor(PatternRewriter &rewriter, Operation *op,
Expand Down

0 comments on commit eb8ea86

Please sign in to comment.