Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@ add_mlir_doc(TosaOps TosaOps Dialects/ -gen-op-doc)
add_mlir_interface(TosaInterfaces)

set(LLVM_TARGET_DEFINITIONS TosaOps.td)
mlir_tablegen(TosaOpsTypesBase.h.inc -gen-typedef-decls -typedefs-dialect=tosa)
mlir_tablegen(TosaOpsTypesBase.cpp.inc -gen-typedef-defs -typedefs-dialect=tosa)
mlir_tablegen(TosaAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=tosa)
mlir_tablegen(TosaAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=tosa)
add_public_tablegen_target(MLIRTosaAttributesIncGen)

set(LLVM_TARGET_DEFINITIONS TosaDialectBytecode.td)
mlir_tablegen(TosaDialectBytecode.cpp.inc -gen-bytecode -bytecode-dialect="Tosa")
add_public_tablegen_target(MLIRTosaDialectBytecodeIncGen)

12 changes: 11 additions & 1 deletion mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def Tosa_Dialect : Dialect {
let cppNamespace = "mlir::tosa";
let hasConstantMaterializer = 1;
let useDefaultAttributePrinterParser = 1;
let useDefaultTypePrinterParser = 1;
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -217,12 +218,21 @@ def TosaElementwiseOperator : NativeOpTrait<"TosaElementwiseOperator"> {
let cppNamespace = "mlir::OpTrait::tosa";
}

//===----------------------------------------------------------------------===//
// TOSA Operator Trait.
//===----------------------------------------------------------------------===//
// Op operands with TOSA shape types must be compile time resolvable
def TosaResolvableShapeOperands : NativeOpTrait<"TosaResolvableShapeOperands"> {
let cppNamespace = "mlir::OpTrait::tosa";
}

//===----------------------------------------------------------------------===//
// TOSA Operator Class.
//===----------------------------------------------------------------------===//

class Tosa_Op<string mnemonic, list<Trait> traits = []> :
Op<Tosa_Dialect, mnemonic, !listconcat(traits, [TosaOpInterface])> {
Op<Tosa_Dialect, mnemonic, !listconcat(traits, [TosaOpInterface,
TosaResolvableShapeOperands])> {
}

class Tosa_ElementwiseOp<string mnemonic, list<Trait> traits = []> :
Expand Down
41 changes: 41 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,55 @@ template <typename ConcreteType>
class TosaElementwiseOperator
: public TraitBase<ConcreteType, TosaElementwiseOperator> {};

LogicalResult verifyTosaResolvableShapeOperands(Operation *op);
/// This class verifies that tosa shape operands are compile time resolvable
template <typename ConcreteType>
class TosaResolvableShapeOperands
: public TraitBase<ConcreteType, TosaResolvableShapeOperands> {
public:
static LogicalResult verifyTrait(Operation *op) {
return verifyTosaResolvableShapeOperands(op);
}
};

LogicalResult verifyTosaShapeOperator(Operation *op);
/// This class indicates that op operates on tosa shape types
template <typename ConcreteType>
class TosaShapeOperator : public TraitBase<ConcreteType, TosaShapeOperator> {
public:
static LogicalResult verifyTrait(Operation *op) {
return verifyTosaShapeOperator(op);
}
};

LogicalResult verifyTosaShapeOperatorWithSameRanks(Operation *op);
/// This class indicates that op operates on tosa shape types
template <typename ConcreteType>
class TosaShapeOperatorWithSameRanks
: public TraitBase<ConcreteType, TosaShapeOperatorWithSameRanks> {
public:
static LogicalResult verifyTrait(Operation *op) {
return verifyTosaShapeOperatorWithSameRanks(op);
}
};

} // namespace tosa
} // namespace OpTrait

namespace tosa {

bool isa_tosa_shape_type(mlir::Type t);

} // namespace tosa

} // namespace mlir

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaAttributes.h.inc"

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.h.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaOps.h.inc"

Expand Down
8 changes: 7 additions & 1 deletion mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1713,12 +1713,16 @@ def Tosa_TileOp : Tosa_InferShapedTypeOp<"tile"> {

let arguments = (ins
Tosa_Tensor:$input1,
DenseI64ArrayAttr:$multiples);
Tosa_Shape:$multiples);

let results = (outs
Tosa_Tensor:$output
);

let extraClassDeclaration = [{
LogicalResult getConstantMultiples(llvm::SmallVector<int64_t> &multiples);
}];

let hasFolder = 1;
let hasVerifier = 1;
}
Expand Down Expand Up @@ -2130,4 +2134,6 @@ def Tosa_WhileOp : Tosa_Op<"while_loop", [

include "mlir/Dialect/Tosa/IR/TosaUtilOps.td"

include "mlir/Dialect/Tosa/IR/TosaShapeOps.td"

#endif // TOSA_OPS
77 changes: 77 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
//===-- TosaShapeOps.td - TOSA dialect utility operations --*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines shape operators for the TOSA dialect.
//
//===----------------------------------------------------------------------===//

#ifndef TOSA_SHAPE_OPS
#define TOSA_SHAPE_OPS

include "mlir/IR/OpBase.td"

include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/Dialect/Tosa/IR/TosaInterfaces.td"

include "mlir/Dialect/Tosa/IR/TosaTypesBase.td"
include "mlir/Dialect/Tosa/IR/TosaOpBase.td"

// Op trait: operator has operands and results with TOSA shape type
def TosaShapeOperator : NativeOpTrait<"TosaShapeOperator"> {
let cppNamespace = "mlir::OpTrait::tosa";
}

class Tosa_ShapeOp<string mnemonic, list<Trait> traits = []>
: Tosa_Op<mnemonic, !listconcat(traits, [TosaShapeOperator, Pure])> {

let assemblyFormat =
"operands attr-dict `:` functional-type(operands, results)";

let hasFolder = 1;
}

// op trait: shape operator has same ranks for operands and results
def TosaShapeOperatorWithSameRanks
: NativeOpTrait<"TosaShapeOperatorWithSameRanks"> {
let cppNamespace = "mlir::OpTrait::tosa";
}

class Tosa_ElementwiseShapeOp<string mnemonic, list<Trait> traits = []>
: Tosa_ShapeOp<mnemonic,
!listconcat(traits, [TosaShapeOperatorWithSameRanks])> {
}


//===----------------------------------------------------------------------===//
// Operator: ConstShape
//===----------------------------------------------------------------------===//
def Tosa_ConstShapeOp : Tosa_ShapeOp<"const_shape", [ConstantLike, Pure]> {
let summary = "Constant Shape op.";

let description = [{
A node containing constant data for use as the input to an shape operation. May
hold data only in index data type.

Example:

```mlir
// Generic form
%out = "tosa.const_shape"() {value = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
```
}];

let arguments = (ins IndexElementsAttr : $value);

let results = (outs Tosa_Shape : $output);

let hasVerifier = 1;
}

#endif // TOSA_SHAPE_OPS
65 changes: 65 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,11 @@
#ifndef TOSA_TYPES_BASE
#define TOSA_TYPES_BASE

include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/OpBase.td"

include "mlir/Dialect/Tosa/IR/TosaOpBase.td"

//===----------------------------------------------------------------------===//
// Tosa Type Definitions.
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -218,4 +221,66 @@ def Tosa_Buffer : MemRefOf<[Tosa_AnyNumber]>;
def Tosa_TupleBuffer : NestedTupleOf<[Tosa_Buffer]>;
def Tosa_BufOrTuple : AnyTypeOf<[Tosa_Buffer, Tosa_TupleBuffer]>;

//===----------------------------------------------------------------------===//
// Tosa Type Definitions.
//===----------------------------------------------------------------------===//

// The base class for Tosa dialect types.
class Tosa_Type<string name, string typeMnemonic, list<Trait> traits = []>
: TypeDef<Tosa_Dialect, name, traits> {
let mnemonic = typeMnemonic;
}

//===----------------------------------------------------------------------===//
// ShapeType
//===----------------------------------------------------------------------===//
def Tosa_Shape : Tosa_Type<"shape", "shape"> {
let summary = "Shape with static rank and Index element type";
let description = [{
Syntax:

``` shape - type :: = `shape` `<` rank `>`
``` Values with shape type represents a shape with a fixed rank and a list
of dimensions
.Rank must be zero or a positive integer
.Each dimension is represented by the builtin
Index type.

Examples:

```mlir
// Shape with rank of four, for example, [1, 1, 8, 16]:
!tosa
.shape<4>

// Shape with rank of one, for example, [16]:
!tosa
.shape<1>

// Shape with rank zero, for example, [] (i.e., shape of scalar values):
!tosa.shape<0>
```
}];
let parameters = (ins "int" : $rank);
let builders = [TypeBuilder<(ins "int" : $rank)>];
let assemblyFormat = "`<` $rank `>`";

let genVerifyDecl = 1;
}

def IsTosaShapeType : CPred<"mlir::tosa::isa_tosa_shape_type($_self)">;

// Whether a Tosa Shape type has a rank equal to the specified rank.
class IsTosaShapeOfRankPred<int rank> : And<[
IsTosaShapeType,
CPred<[{::llvm::cast<::mlir::tosa::shapeType>($_self).getRank() == }] # rank>
]>;

class TosaShapeOfRank<int rank>
: Type<IsTosaShapeOfRankPred<rank>, "Tosa shape type of rank " #rank>;

def Rank1TosaShape : TosaShapeOfRank<1>;
def Rank2TosaShape : TosaShapeOfRank<2>;
def Rank4TosaShape : TosaShapeOfRank<4>;

#endif // TOSA_TYPES_BASE
4 changes: 3 additions & 1 deletion mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2033,7 +2033,9 @@ struct TileConverter : public OpConversionPattern<tosa::TileOp> {
auto elementTy = inputTy.getElementType();
int64_t rank = inputTy.getRank();

ArrayRef<int64_t> multiples = operands.getMultiples();
SmallVector<int64_t> multiples;
if (failed(op.getConstantMultiples(multiples)))
return failure();

// Broadcast the newly added dimensions to their appropriate multiple.
SmallVector<int64_t, 2> genericShape;
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ struct TosaToLinalg : public impl::TosaToLinalgBase<TosaToLinalg> {
target.addLegalOp<tosa::ApplyScaleOp>();
target.addLegalOp<tosa::IfOp>();
target.addLegalOp<tosa::ConstOp>();
target.addLegalOp<tosa::ConstShapeOp>();
target.addLegalOp<tosa::WhileOp>();
target.addLegalOp<tosa::ConcatOp>();
target.addLegalOp<tosa::SliceOp>();
Expand Down
19 changes: 16 additions & 3 deletions mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1137,6 +1137,8 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) {

OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }

OpFoldResult ConstShapeOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }

#define REDUCE_FOLDER(OP) \
OpFoldResult OP::fold(FoldAdaptor adaptor) { \
ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType()); \
Expand Down Expand Up @@ -1318,9 +1320,20 @@ OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
}

OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
bool allOnes = llvm::all_of(getMultiples(), [](int64_t v) { return v == 1; });
if (allOnes && getInput1().getType() == getType())
return getInput1();
if (getInput1().getType() == getType()) {
if (auto multiples = llvm::dyn_cast_if_present<DenseElementsAttr>(
adaptor.getMultiples())) {
if (multiples.isSplat() &&
multiples.getSplatValue<APInt>().getSExtValue() == 1)
return getInput1();
if (auto int_array_attr =
llvm::dyn_cast<DenseIntElementsAttr>(multiples)) {
if (llvm::all_of(int_array_attr.getValues<APInt>(),
[](APInt v) { return v.getSExtValue() == 1; }))
return getInput1();
}
}
}
return {};
}

Expand Down
Loading
Loading