diff --git a/include/torch-mlir/Dialect/Torch/Utils/Utils.h b/include/torch-mlir/Dialect/Torch/Utils/Utils.h
index 1aaf546c231..24db6f14f35 100644
--- a/include/torch-mlir/Dialect/Torch/Utils/Utils.h
+++ b/include/torch-mlir/Dialect/Torch/Utils/Utils.h
@@ -87,6 +87,10 @@ int64_t getNumberOfElements(RankedTensorType inputType);
 SmallVector<int64_t> makeShapeLLVMCompatible(ArrayRef<int64_t> shape);
 SmallVector<int64_t> makeShapeTorchCompatible(ArrayRef<int64_t> shape);
 
+ValueTensorType getTensorTypeFromShapeValues(ArrayRef<Value> shapes,
+                                             Type dtype);
+Value getTensorDimSize(PatternRewriter &rewriter, Value tensor, int64_t dim);
+
 // Helper function to squeeze the input tensor at given dim.
 // Return the squeezed tensor or failure.
 FailureOr<Value> squeezeTensor(PatternRewriter &rewriter, Operation *op,
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 9b415d685f5..757dd2d0b86 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -674,7 +674,6 @@ class DecomposeAtenTriuOp : public OpRewritePattern<AtenTriuOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(AtenTriuOp op,
                                 PatternRewriter &rewriter) const override {
-    MLIRContext *context = op.getContext();
     Location loc = op.getLoc();
     Value input = op.getSelf();
     auto inputType = cast<BaseTensorType>(input.getType());
@@ -685,37 +684,50 @@ class DecomposeAtenTriuOp : public OpRewritePattern<AtenTriuOp> {
       return rewriter.notifyMatchFailure(op,
                                          "the rank of tensor should >= 2");
     }
 
-    auto baseType = ValueTensorType::getWithLeastStaticInformation(context);
     Value cstZero =
         rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
     Value cstOne =
         rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
     Value none = rewriter.create<ConstantNoneOp>(loc);
-    Value rowDim = rewriter.create<ConstantIntOp>(
-        loc, rewriter.getI64IntegerAttr(-2));
-    Value colDim = rewriter.create<ConstantIntOp>(
-        loc, rewriter.getI64IntegerAttr(-1));
-    Value rowSize = rewriter.create<AtenSizeIntOp>(loc, input, rowDim);
-    Value colSize = rewriter.create<AtenSizeIntOp>(loc, input, colDim);
-
-    Value rowArange = rewriter.create<AtenArangeOp>(
-        loc, baseType, rowSize, /*dtype=*/none, /*layout=*/none,
-        /*device=*/none, /*pin_memory=*/none);
-    Value colArange = rewriter.create<AtenArangeOp>(
-        loc, baseType, colSize, /*dtype=*/none, /*layout=*/none,
-        /*device=*/none, /*pin_memory=*/none);
+    Value rowSize = getTensorDimSize(rewriter, input, -2);
+    Value colSize = getTensorDimSize(rewriter, input, -1);
+
+    auto si64Type = rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true);
+    auto int64DtypeInt = getDtypeIntValueForType(rewriter, loc, si64Type);
+    auto rowArangeType = getTensorTypeFromShapeValues({rowSize}, si64Type);
+    auto colArangeType = getTensorTypeFromShapeValues({colSize}, si64Type);
+
+    Value rowArange =
+        rewriter.create<AtenArangeOp>(loc, rowArangeType, rowSize,
+                                      /*dtype=*/int64DtypeInt, /*layout=*/none,
+                                      /*device=*/none, /*pin_memory=*/none);
+    Value colArange =
+        rewriter.create<AtenArangeOp>(loc, colArangeType, colSize,
+                                      /*dtype=*/int64DtypeInt, /*layout=*/none,
+                                      /*device=*/none, /*pin_memory=*/none);
+
+    auto unsqueezeRowArangeInfo =
+        unsqueezeTensor(rewriter, op, rowArange, cstOne);
+    auto unsqueezeColArangeInfo =
+        unsqueezeTensor(rewriter, op, colArange, cstZero);
+
+    if (failed(unsqueezeRowArangeInfo) || failed(unsqueezeColArangeInfo)) {
+      return rewriter.notifyMatchFailure(op,
+                                         "cannot generate unsqueeze tensor");
+    }
 
-    Value unsqueezeRowArange =
-        rewriter.create<AtenUnsqueezeOp>(loc, baseType, rowArange, cstOne);
-    Value unsqueezeColArange =
-        rewriter.create<AtenUnsqueezeOp>(loc, baseType, colArange, cstZero);
+    Value unsqueezeRowArange = unsqueezeRowArangeInfo.value();
+    Value unsqueezeColArange = unsqueezeColArangeInfo.value();
 
     Value unsqueezeRowArangePlusDiagonal = rewriter.create<AtenAddScalarOp>(
-        loc, baseType, unsqueezeRowArange, op.getDiagonal(), cstOne);
+        loc, unsqueezeRowArange.getType(), unsqueezeRowArange, op.getDiagonal(),
+        cstOne);
 
+    auto boolType = rewriter.getI1Type();
+    auto condType = getTensorTypeFromShapeValues({rowSize, colSize}, boolType);
     Value condTensor = rewriter.create<AtenGeTensorOp>(
-        loc, baseType, unsqueezeColArange, unsqueezeRowArangePlusDiagonal);
+        loc, condType, unsqueezeColArange, unsqueezeRowArangePlusDiagonal);
 
     rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(
         op, op.getResult().getType(), condTensor, input, cstZero);
diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp
index 8101a2a5b4b..197f09c66b9 100644
--- a/lib/Dialect/Torch/Utils/Utils.cpp
+++ b/lib/Dialect/Torch/Utils/Utils.cpp
@@ -289,6 +289,32 @@ SmallVector<int64_t> Torch::makeShapeTorchCompatible(ArrayRef<int64_t> shape) {
   return updatedShape;
 }
 
+ValueTensorType Torch::getTensorTypeFromShapeValues(ArrayRef<Value> shapes,
+                                                    Type dtype) {
+  assert(!shapes.empty() && "shape vector cannot be empty");
+  SmallVector<int64_t> shapeInts;
+  for (Value shape : shapes) {
+    int64_t dim;
+    if (matchPattern(shape, m_TorchConstantInt(&dim)))
+      shapeInts.push_back(dim);
+    else
+      shapeInts.push_back(kUnknownSize);
+  }
+  return Torch::ValueTensorType::get(shapes[0].getContext(), shapeInts, dtype);
+}
+
+// Helper function to get the size of the tensor at the given dimension.
+Value Torch::getTensorDimSize(PatternRewriter &rewriter, Value tensor,
+                              int64_t dim) {
+  auto loc = tensor.getLoc();
+  auto dimVal =
+      rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(dim));
+  // Use 'createOrFold' instead of 'create':
+  // If the dimension is a constant, then the AtenSizeIntOp is folded to a
+  // ConstantIntOp.
+  return rewriter.createOrFold<AtenSizeIntOp>(loc, tensor, dimVal);
+}
+
 // Helper function to squeeze the input tensor at given dim.
 // Return the squeezed tensor or failure.
 FailureOr<Value> Torch::squeezeTensor(PatternRewriter &rewriter, Operation *op,