diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index 0ac000e2bd978..03317439740e6 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -1585,6 +1585,7 @@ def Tosa_SliceOp: Tosa_Op<"slice", [ let hasCanonicalizer = 1; let hasFolder = 1; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 5c17d281c2ec7..f68d92a672d91 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -742,6 +742,12 @@ mlir::LogicalResult tosa::ReshapeOp::verify() { ShapedType outputType = getType().cast(); if (inputType.hasStaticShape() && outputType.hasStaticShape()) { + if (getNewShape() != outputType.getShape()) { + return emitOpError() << "newShape attribute " << getNewShape() + << " does not match output type " + << outputType.getShape(); + } + int64_t inputElementsNum = inputType.getNumElements(); int64_t outputElementsNum = outputType.getNumElements(); if (inputElementsNum != outputElementsNum) { @@ -749,6 +755,52 @@ mlir::LogicalResult tosa::ReshapeOp::verify() { << " elements into " << outputElementsNum; } } + + return mlir::success(); +} + +mlir::LogicalResult tosa::SliceOp::verify() { + // TODO: Complete verification + ShapedType inputType = getInput().getType().cast(); + ShapedType outputType = getType().cast(); + + if (inputType.getRank() != outputType.getRank()) { + return emitOpError() << "rank of input (" << inputType.getRank() + << ") and output (" + << outputType.getRank() + << ") must match"; + } + + if (getSize() != outputType.getShape()) { + return emitOpError() << "size attribute " << getSize() + << " does not match output type " + << outputType.getShape(); + } + + if ((int64_t)getStart().size() != inputType.getRank()) { + return emitOpError() << "rank of start (" << getStart().size() + << ") and input (" + << inputType.getRank() + << ") must match"; + } + if ((int64_t)getSize().size() != inputType.getRank()) { + return emitOpError() << "rank of size (" << getSize().size() + << ") and input (" + << inputType.getRank() + << ") must match"; + } + + for (int i = 0; i < outputType.getRank(); ++i) { + auto dimSize = inputType.getShape()[i]; + if (dimSize != ShapedType::kDynamic && getStart()[i] + getSize()[i] > inputType.getShape()[i]) { + return emitOpError() << "start (" << getStart()[i] + << ") plus size (" + << getSize()[i] + << ") goes out of bounds of input size (" + << inputType.getShape()[i] + << ") in dimension " << i; + } + } return mlir::success(); }