Skip to content

Commit a3d4187

Browse files
[mlir][ODS] Optionally generate public C++ functions for type constraints (llvm#104577)
Add `gen-type-constraint-decls` and `gen-type-constraint-defs`, which generate public C++ functions for type constraints. The name of the C++ function is specified in the `cppFunctionName` field. Type constraints are typically used for op/type/attribute verification. They are also sometimes called from builders and transformations. Until now, this required duplicating the check in C++. Note: This commit just adds the option for type constraints, but attribute constraints could be supported in the same way. Alternatives considered: 1. The C++ functions could also be generated as part of `gen-typedef-decls/defs`, but that can be confusing because type constraints may rely on type definitions from multiple `.td` files. `#include`s could cause duplicate definitions of the same type constraint. 2. The C++ functions could also be generated as static member functions of dialects, but they don't really belong to a dialect. (Because they may rely on type definitions from multiple dialects.)
1 parent 90556ef commit a3d4187

File tree

11 files changed

+181
-11
lines changed

11 files changed

+181
-11
lines changed
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# Constraints
2+
3+
[TOC]
4+
5+
## Attribute / Type Constraints
6+
7+
When defining the arguments of an operation in TableGen, users can specify
8+
either plain attributes/types or use attribute/type constraints to levy
9+
additional requirements on the attribute value or operand type.
10+
11+
```tablegen
12+
def My_Type1 : MyDialect_Type<"Type1", "type1"> { ... }
13+
def My_Type2 : MyDialect_Type<"Type2", "type2"> { ... }
14+
15+
// Plain type
16+
let arguments = (ins MyType1:$val);
17+
// Type constraint
18+
let arguments = (ins AnyTypeOf<[MyType1, MyType2]>:$val);
19+
```
20+
21+
`AnyTypeOf` is an example for a type constraints. Many useful type constraints
22+
can be found in `mlir/IR/CommonTypeConstraints.td`. Additional verification
23+
code is generated for type/attribute constraints. Type constraints can not only
24+
be used when defining operation arguments, but also when defining type
25+
parameters.
26+
27+
Optionally, C++ functions can be generated, so that type constraints can be
28+
checked from C++. The name of the C++ function must be specified in the
29+
`cppFunctionName` field. If no function name is specified, no C++ function is
30+
emitted.
31+
32+
```tablegen
33+
// Example: Element type constraint for VectorType
34+
def Builtin_VectorTypeElementType : AnyTypeOf<[AnyInteger, Index, AnyFloat]> {
35+
let cppFunctionName = "isValidVectorTypeElementType";
36+
}
37+
```
38+
39+
The above example tranlates into the following C++ code:
40+
```c++
41+
bool isValidVectorTypeElementType(::mlir::Type type) {
42+
return (((::llvm::isa<::mlir::IntegerType>(type))) || ((::llvm::isa<::mlir::IndexType>(type))) || ((::llvm::isa<::mlir::FloatType>(type))));
43+
}
44+
```
45+
46+
An extra TableGen rule is needed to emit C++ code for type constraints. This
47+
will generate only the declarations/definitions of the type constaraints that
48+
are defined in the specified `.td` file, but not those that are in included
49+
`.td` files.
50+
51+
```cmake
52+
mlir_tablegen(<Your Dialect>TypeConstraints.h.inc -gen-type-constraint-decls)
53+
mlir_tablegen(<Your Dialect>TypeConstraints.cpp.inc -gen-type-constraint-defs)
54+
```
55+
56+
The generated `<Your Dialect>TypeConstraints.h.inc` will need to be included
57+
whereever you are referencing the type constraint in C++. Note that no C++
58+
namespace will be emitted by the code generator. The `#include` statements of
59+
the `.h.inc`/`.cpp.inc` files should be wrapped in C++ namespaces by the user.

mlir/include/mlir/IR/BuiltinTypes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,7 @@ class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
198198
#include "mlir/IR/BuiltinTypes.h.inc"
199199

200200
namespace mlir {
201+
#include "mlir/IR/BuiltinTypeConstraints.h.inc"
201202

202203
//===----------------------------------------------------------------------===//
203204
// MemRefType

mlir/include/mlir/IR/BuiltinTypes.td

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1097,6 +1097,10 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
10971097
// VectorType
10981098
//===----------------------------------------------------------------------===//
10991099

1100+
def Builtin_VectorTypeElementType : AnyTypeOf<[AnyInteger, Index, AnyFloat]> {
1101+
let cppFunctionName = "isValidVectorTypeElementType";
1102+
}
1103+
11001104
def Builtin_Vector : Builtin_Type<"Vector", "vector",
11011105
[ShapedTypeInterface, ValueSemantics], "Type"> {
11021106
let summary = "Multi-dimensional SIMD vector type";
@@ -1147,7 +1151,7 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector",
11471151
}];
11481152
let parameters = (ins
11491153
ArrayRefParameter<"int64_t">:$shape,
1150-
AnyTypeOf<[AnyInteger, Index, AnyFloat]>:$elementType,
1154+
Builtin_VectorTypeElementType:$elementType,
11511155
ArrayRefParameter<"bool">:$scalableDims
11521156
);
11531157
let builders = [
@@ -1171,12 +1175,8 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector",
11711175
class Builder;
11721176

11731177
/// Returns true if the given type can be used as an element of a vector
1174-
/// type. In particular, vectors can consist of integer, index, or float
1175-
/// primitives.
1176-
static bool isValidElementType(Type t) {
1177-
// TODO: Auto-generate this function from $elementType.
1178-
return ::llvm::isa<IntegerType, IndexType, FloatType>(t);
1179-
}
1178+
/// type. See "Builtin_VectorTypeElementType" for allowed types.
1179+
static bool isValidElementType(Type t);
11801180

11811181
/// Returns true if the vector contains scalable dimensions.
11821182
bool isScalable() const {

mlir/include/mlir/IR/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ set(LLVM_TARGET_DEFINITIONS BuiltinTypes.td)
3535
mlir_tablegen(BuiltinTypes.h.inc -gen-typedef-decls)
3636
mlir_tablegen(BuiltinTypes.cpp.inc -gen-typedef-defs)
3737
add_public_tablegen_target(MLIRBuiltinTypesIncGen)
38+
mlir_tablegen(BuiltinTypeConstraints.h.inc -gen-type-constraint-decls)
39+
mlir_tablegen(BuiltinTypeConstraints.cpp.inc -gen-type-constraint-defs)
40+
add_public_tablegen_target(MLIRBuiltinTypeConstraintsIncGen)
3841

3942
set(LLVM_TARGET_DEFINITIONS BuiltinTypeInterfaces.td)
4043
mlir_tablegen(BuiltinTypeInterfaces.h.inc -gen-type-interface-decls)

mlir/include/mlir/IR/Constraints.td

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,10 +149,14 @@ class Constraint<Pred pred, string desc = ""> {
149149

150150
// Subclass for constraints on a type.
151151
class TypeConstraint<Pred predicate, string summary = "",
152-
string cppTypeParam = "::mlir::Type"> :
152+
string cppTypeParam = "::mlir::Type",
153+
string cppFunctionNameParam = ""> :
153154
Constraint<predicate, summary> {
154155
// The name of the C++ Type class if known, or Type if not.
155156
string cppType = cppTypeParam;
157+
// The name of the C++ function that is generated for this type constraint.
158+
// If empty, no C++ function is generated.
159+
string cppFunctionName = cppFunctionNameParam;
156160
}
157161

158162
// Subclass for constraints on an attribute.

mlir/include/mlir/TableGen/Constraint.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,10 @@ class Constraint {
6969
/// context on the def).
7070
std::string getUniqueDefName() const;
7171

72+
/// Returns the name of the C++ function that should be generated for this
73+
/// constraint, or std::nullopt if no C++ function should be generated.
74+
std::optional<StringRef> getCppFunctionName() const;
75+
7276
Kind getKind() const { return kind; }
7377

7478
/// Return the underlying def.

mlir/lib/IR/BuiltinTypes.cpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@ using namespace mlir::detail;
3232
#define GET_TYPEDEF_CLASSES
3333
#include "mlir/IR/BuiltinTypes.cpp.inc"
3434

35+
namespace mlir {
36+
#include "mlir/IR/BuiltinTypeConstraints.cpp.inc"
37+
} // namespace mlir
38+
3539
//===----------------------------------------------------------------------===//
3640
// BuiltinDialect
3741
//===----------------------------------------------------------------------===//
@@ -230,6 +234,10 @@ LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError,
230234
// VectorType
231235
//===----------------------------------------------------------------------===//
232236

237+
bool VectorType::isValidElementType(Type t) {
238+
return isValidVectorTypeElementType(t);
239+
}
240+
233241
LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
234242
ArrayRef<int64_t> shape, Type elementType,
235243
ArrayRef<bool> scalableDims) {
@@ -278,7 +286,9 @@ Type TensorType::getElementType() const {
278286
[](auto type) { return type.getElementType(); });
279287
}
280288

281-
bool TensorType::hasRank() const { return !llvm::isa<UnrankedTensorType>(*this); }
289+
bool TensorType::hasRank() const {
290+
return !llvm::isa<UnrankedTensorType>(*this);
291+
}
282292

283293
ArrayRef<int64_t> TensorType::getShape() const {
284294
return llvm::cast<RankedTensorType>(*this).getShape();
@@ -365,7 +375,9 @@ Type BaseMemRefType::getElementType() const {
365375
[](auto type) { return type.getElementType(); });
366376
}
367377

368-
bool BaseMemRefType::hasRank() const { return !llvm::isa<UnrankedMemRefType>(*this); }
378+
bool BaseMemRefType::hasRank() const {
379+
return !llvm::isa<UnrankedMemRefType>(*this);
380+
}
369381

370382
ArrayRef<int64_t> BaseMemRefType::getShape() const {
371383
return llvm::cast<MemRefType>(*this).getShape();

mlir/lib/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ add_mlir_library(MLIRIR
5555
MLIRBuiltinLocationAttributesIncGen
5656
MLIRBuiltinOpsIncGen
5757
MLIRBuiltinTypesIncGen
58+
MLIRBuiltinTypeConstraintsIncGen
5859
MLIRBuiltinTypeInterfacesIncGen
5960
MLIRCallInterfacesIncGen
6061
MLIRCastInterfacesIncGen

mlir/lib/TableGen/Constraint.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ Constraint::Constraint(const llvm::Record *record)
3030
kind = CK_Region;
3131
} else if (def->isSubClassOf("SuccessorConstraint")) {
3232
kind = CK_Successor;
33-
} else if(!def->isSubClassOf("Constraint")) {
33+
} else if (!def->isSubClassOf("Constraint")) {
3434
llvm::errs() << "Expected a constraint but got: \n" << *def << "\n";
3535
llvm::report_fatal_error("Abort");
3636
}
@@ -109,6 +109,14 @@ std::optional<StringRef> Constraint::getBaseDefName() const {
109109
}
110110
}
111111

112+
std::optional<StringRef> Constraint::getCppFunctionName() const {
113+
std::optional<StringRef> name =
114+
def->getValueAsOptionalString("cppFunctionName");
115+
if (!name || *name == "")
116+
return std::nullopt;
117+
return name;
118+
}
119+
112120
AppliedConstraint::AppliedConstraint(Constraint &&constraint,
113121
llvm::StringRef self,
114122
std::vector<std::string> &&entities)
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
// RUN: mlir-tblgen -gen-type-constraint-decls -I %S/../../include %s | FileCheck %s --check-prefix=DECL
2+
// RUN: mlir-tblgen -gen-type-constraint-defs -I %S/../../include %s | FileCheck %s --check-prefix=DEF
3+
4+
include "mlir/IR/CommonTypeConstraints.td"
5+
6+
def DummyConstraint : AnyTypeOf<[AnyInteger, Index, AnyFloat]> {
7+
let cppFunctionName = "isValidDummy";
8+
}
9+
10+
// DECL: bool isValidDummy(::mlir::Type type);
11+
12+
// DEF: bool isValidDummy(::mlir::Type type) {
13+
// DEF: return (((::llvm::isa<::mlir::IntegerType>(type))) || ((::llvm::isa<::mlir::IndexType>(type))) || ((::llvm::isa<::mlir::FloatType>(type))));
14+
// DEF: }

mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1023,6 +1023,55 @@ bool DefGenerator::emitDefs(StringRef selectedDialect) {
10231023
return false;
10241024
}
10251025

1026+
//===----------------------------------------------------------------------===//
1027+
// Type Constraints
1028+
//===----------------------------------------------------------------------===//
1029+
1030+
/// Find all type constraints for which a C++ function should be generated.
1031+
static std::vector<Constraint>
1032+
getAllTypeConstraints(const llvm::RecordKeeper &records) {
1033+
std::vector<Constraint> result;
1034+
for (llvm::Record *def :
1035+
records.getAllDerivedDefinitionsIfDefined("TypeConstraint")) {
1036+
// Ignore constraints defined outside of the top-level file.
1037+
if (llvm::SrcMgr.FindBufferContainingLoc(def->getLoc()[0]) !=
1038+
llvm::SrcMgr.getMainFileID())
1039+
continue;
1040+
Constraint constr(def);
1041+
// Generate C++ function only if "cppFunctionName" is set.
1042+
if (!constr.getCppFunctionName())
1043+
continue;
1044+
result.push_back(constr);
1045+
}
1046+
return result;
1047+
}
1048+
1049+
static void emitTypeConstraintDecls(const llvm::RecordKeeper &records,
1050+
raw_ostream &os) {
1051+
static const char *const typeConstraintDecl = R"(
1052+
bool {0}(::mlir::Type type);
1053+
)";
1054+
1055+
for (Constraint constr : getAllTypeConstraints(records))
1056+
os << strfmt(typeConstraintDecl, *constr.getCppFunctionName());
1057+
}
1058+
1059+
static void emitTypeConstraintDefs(const llvm::RecordKeeper &records,
1060+
raw_ostream &os) {
1061+
static const char *const typeConstraintDef = R"(
1062+
bool {0}(::mlir::Type type) {
1063+
return ({1});
1064+
}
1065+
)";
1066+
1067+
for (Constraint constr : getAllTypeConstraints(records)) {
1068+
FmtContext ctx;
1069+
ctx.withSelf("type");
1070+
std::string condition = tgfmt(constr.getConditionTemplate(), &ctx);
1071+
os << strfmt(typeConstraintDef, *constr.getCppFunctionName(), condition);
1072+
}
1073+
}
1074+
10261075
//===----------------------------------------------------------------------===//
10271076
// GEN: Registration hooks
10281077
//===----------------------------------------------------------------------===//
@@ -1070,3 +1119,18 @@ static mlir::GenRegistration
10701119
TypeDefGenerator generator(records, os);
10711120
return generator.emitDecls(typeDialect);
10721121
});
1122+
1123+
static mlir::GenRegistration
1124+
genTypeConstrDefs("gen-type-constraint-defs",
1125+
"Generate type constraint definitions",
1126+
[](const llvm::RecordKeeper &records, raw_ostream &os) {
1127+
emitTypeConstraintDefs(records, os);
1128+
return false;
1129+
});
1130+
static mlir::GenRegistration
1131+
genTypeConstrDecls("gen-type-constraint-decls",
1132+
"Generate type constraint declarations",
1133+
[](const llvm::RecordKeeper &records, raw_ostream &os) {
1134+
emitTypeConstraintDecls(records, os);
1135+
return false;
1136+
});

0 commit comments

Comments
 (0)