Skip to content

Commit 45c226d

Browse files
authored
[MLIR] Add ODS support for generating helpers for dialect (discardable) attributes (llvm#77024)
This is a new ODS feature that allows dialects to define a list of key/value pair representing an attribute type and a name. This will generate helper classes on the dialect to be able to manage discardable attributes on operations in a type safe way. For example the `test` dialect can define: ``` let discardableAttrs = (ins "mlir::IntegerAttr":$discardable_attr_key, ); ``` And the following will be generated in the TestDialect class: ``` /// Helper to manage the discardable attribute `discardable_attr_key`. class DiscardableAttrKeyAttrHelper { ::mlir::StringAttr name; public: static constexpr ::llvm::StringLiteral getNameStr() { return "test.discardable_attr_key"; } constexpr ::mlir::StringAttr getName() { return name; } DiscardableAttrKeyAttrHelper(::mlir::MLIRContext *ctx) : name(::mlir::StringAttr::get(ctx, getNameStr())) {} mlir::IntegerAttr getAttr(::mlir::Operation *op) { return op->getAttrOfType<mlir::IntegerAttr>(name); } void setAttr(::mlir::Operation *op, mlir::IntegerAttr val) { op->setAttr(name, val); } bool isAttrPresent(::mlir::Operation *op) { return op->hasAttrOfType<mlir::IntegerAttr>(name); } void removeAttr(::mlir::Operation *op) { assert(op->hasAttrOfType<mlir::IntegerAttr>(name)); op->removeAttr(name); } }; DiscardableAttrKeyAttrHelper getDiscardableAttrKeyAttrHelper() { return discardableAttrKeyAttrName; } ``` User code having an instance of the TestDialect can then manipulate this attribute on operation using: ``` auto helper = testDialect.getDiscardableAttrKeyAttrHelper(); helper.setAttr(op, value); helper.isAttrPresent(op); ... ```
1 parent 5911334 commit 45c226d

File tree

9 files changed

+136
-29
lines changed

9 files changed

+136
-29
lines changed

mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,16 +28,6 @@ def ROCDL_Dialect : Dialect {
2828
let hasOperationAttrVerify = 1;
2929

3030
let extraClassDeclaration = [{
31-
/// Get the name of the attribute used to annotate external kernel
32-
/// functions.
33-
static StringRef getKernelFuncAttrName() { return "rocdl.kernel"; }
34-
static constexpr ::llvm::StringLiteral getFlatWorkGroupSizeAttrName() {
35-
return ::llvm::StringLiteral("rocdl.flat_work_group_size");
36-
}
37-
static constexpr ::llvm::StringLiteral getReqdWorkGroupSizeAttrName() {
38-
return ::llvm::StringLiteral("rocdl.reqd_work_group_size");
39-
}
40-
4131
/// The address space value that represents global memory.
4232
static constexpr unsigned kGlobalMemoryAddressSpace = 1;
4333
/// The address space value that represents shared memory.
@@ -46,6 +36,14 @@ def ROCDL_Dialect : Dialect {
4636
static constexpr unsigned kPrivateMemoryAddressSpace = 5;
4737
}];
4838

39+
let discardableAttrs = (ins
40+
"::mlir::UnitAttr":$kernel,
41+
"::mlir::DenseI32ArrayAttr":$reqd_work_group_size,
42+
"::mlir::StringAttr":$flat_work_group_size,
43+
"::mlir::IntegerAttr":$max_flat_work_group_size,
44+
"::mlir::IntegerAttr":$waves_per_eu
45+
);
46+
4947
let useDefaultAttributePrinterParser = 1;
5048
}
5149

mlir/include/mlir/IR/DialectBase.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,11 @@ class Dialect {
3434
// pattern or interfaces.
3535
list<string> dependentDialects = [];
3636

37+
// A list of key/value pair representing an attribute type and a name.
38+
// This will generate helper classes on the dialect to be able to
39+
// manage discardable attributes on operations in a type safe way.
40+
dag discardableAttrs = (ins);
41+
3742
// The C++ namespace that ops of this dialect should be placed into.
3843
//
3944
// By default, uses the name of the dialect as the only namespace. To avoid

mlir/include/mlir/TableGen/Dialect.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
#define MLIR_TABLEGEN_DIALECT_H_
1515

1616
#include "mlir/Support/LLVM.h"
17+
#include "llvm/TableGen/Record.h"
18+
1719
#include <string>
1820
#include <vector>
1921

@@ -90,6 +92,10 @@ class Dialect {
9092
/// dialect.
9193
bool usePropertiesForAttributes() const;
9294

95+
llvm::DagInit *getDiscardableAttributes() const;
96+
97+
const llvm::Record *getDef() const { return def; }
98+
9399
// Returns whether two dialects are equal by checking the equality of the
94100
// underlying record.
95101
bool operator==(const Dialect &other) const;

mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -285,14 +285,17 @@ struct LowerGpuOpsToROCDLOpsPass
285285
configureGpuToROCDLConversionLegality(target);
286286
if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
287287
signalPassFailure();
288-
288+
auto *rocdlDialect = getContext().getLoadedDialect<ROCDL::ROCDLDialect>();
289+
auto reqdWorkGroupSizeAttrHelper =
290+
rocdlDialect->getReqdWorkGroupSizeAttrHelper();
291+
auto flatWorkGroupSizeAttrHelper =
292+
rocdlDialect->getFlatWorkGroupSizeAttrHelper();
289293
// Manually rewrite known block size attributes so the LLVMIR translation
290294
// infrastructure can pick them up.
291-
m.walk([ctx](LLVM::LLVMFuncOp op) {
295+
m.walk([&](LLVM::LLVMFuncOp op) {
292296
if (auto blockSizes = dyn_cast_or_null<DenseI32ArrayAttr>(
293297
op->removeAttr(gpu::GPUFuncOp::getKnownBlockSizeAttrName()))) {
294-
op->setAttr(ROCDL::ROCDLDialect::getReqdWorkGroupSizeAttrName(),
295-
blockSizes);
298+
reqdWorkGroupSizeAttrHelper.setAttr(op, blockSizes);
296299
// Also set up the rocdl.flat_work_group_size attribute to prevent
297300
// conflicting metadata.
298301
uint32_t flatSize = 1;
@@ -301,8 +304,7 @@ struct LowerGpuOpsToROCDLOpsPass
301304
}
302305
StringAttr flatSizeAttr =
303306
StringAttr::get(ctx, Twine(flatSize) + "," + Twine(flatSize));
304-
op->setAttr(ROCDL::ROCDLDialect::getFlatWorkGroupSizeAttrName(),
305-
flatSizeAttr);
307+
flatWorkGroupSizeAttrHelper.setAttr(op, flatSizeAttr);
306308
}
307309
});
308310
}
@@ -355,8 +357,7 @@ void mlir::populateGpuToROCDLConversionPatterns(
355357
converter,
356358
/*allocaAddrSpace=*/ROCDL::ROCDLDialect::kPrivateMemoryAddressSpace,
357359
/*workgroupAddrSpace=*/ROCDL::ROCDLDialect::kSharedMemoryAddressSpace,
358-
StringAttr::get(&converter.getContext(),
359-
ROCDL::ROCDLDialect::getKernelFuncAttrName()));
360+
ROCDL::ROCDLDialect::KernelAttrHelper(&converter.getContext()).getName());
360361
if (Runtime::HIP == runtime) {
361362
patterns.add<GPUPrintfOpToHIPLowering>(converter);
362363
} else if (Runtime::OpenCL == runtime) {

mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -253,9 +253,9 @@ void ROCDLDialect::initialize() {
253253
LogicalResult ROCDLDialect::verifyOperationAttribute(Operation *op,
254254
NamedAttribute attr) {
255255
// Kernel function attribute should be attached to functions.
256-
if (attr.getName() == ROCDLDialect::getKernelFuncAttrName()) {
256+
if (kernelAttrName.getName() == attr.getName()) {
257257
if (!isa<LLVM::LLVMFuncOp>(op)) {
258-
return op->emitError() << "'" << ROCDLDialect::getKernelFuncAttrName()
258+
return op->emitError() << "'" << kernelAttrName.getName()
259259
<< "' attribute attached to unexpected op";
260260
}
261261
}

mlir/lib/TableGen/Dialect.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,10 @@ bool Dialect::usePropertiesForAttributes() const {
106106
return def->getValueAsBit("usePropertiesForAttributes");
107107
}
108108

109+
llvm::DagInit *Dialect::getDiscardableAttributes() const {
110+
return def->getValueAsDag("discardableAttrs");
111+
}
112+
109113
bool Dialect::operator==(const Dialect &other) const {
110114
return def == other.def;
111115
}

mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,8 @@ class ROCDLDialectLLVMIRTranslationInterface
8484
amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
8585
NamedAttribute attribute,
8686
LLVM::ModuleTranslation &moduleTranslation) const final {
87-
if (attribute.getName() == ROCDL::ROCDLDialect::getKernelFuncAttrName()) {
87+
auto *dialect = dyn_cast<ROCDL::ROCDLDialect>(attribute.getNameDialect());
88+
if (dialect->getKernelAttrHelper().getName() == attribute.getName()) {
8889
auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
8990
if (!func)
9091
return failure();
@@ -99,12 +100,12 @@ class ROCDLDialectLLVMIRTranslationInterface
99100
if (!llvmFunc->hasFnAttribute("amdgpu-flat-work-group-size")) {
100101
llvmFunc->addFnAttr("amdgpu-flat-work-group-size", "1,256");
101102
}
102-
103103
}
104104
// Override flat-work-group-size
105105
// TODO: update clients to rocdl.flat_work_group_size instead,
106106
// then remove this half of the branch
107-
if ("rocdl.max_flat_work_group_size" == attribute.getName()) {
107+
if (dialect->getMaxFlatWorkGroupSizeAttrHelper().getName() ==
108+
attribute.getName()) {
108109
auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
109110
if (!func)
110111
return failure();
@@ -119,7 +120,7 @@ class ROCDLDialectLLVMIRTranslationInterface
119120
attrValueStream << "1," << value.getInt();
120121
llvmFunc->addFnAttr("amdgpu-flat-work-group-size", llvmAttrValue);
121122
}
122-
if (ROCDL::ROCDLDialect::getFlatWorkGroupSizeAttrName() ==
123+
if (dialect->getFlatWorkGroupSizeAttrHelper().getName() ==
123124
attribute.getName()) {
124125
auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
125126
if (!func)
@@ -136,7 +137,7 @@ class ROCDLDialectLLVMIRTranslationInterface
136137
}
137138

138139
// Set reqd_work_group_size metadata
139-
if (ROCDL::ROCDLDialect::getReqdWorkGroupSizeAttrName() ==
140+
if (dialect->getReqdWorkGroupSizeAttrHelper().getName() ==
140141
attribute.getName()) {
141142
auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
142143
if (!func)

mlir/test/lib/Dialect/Test/TestDialect.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ def Test_Dialect : Dialect {
2525
let useDefaultAttributePrinterParser = 1;
2626
let isExtensible = 1;
2727
let dependentDialects = ["::mlir::DLTIDialect"];
28+
let discardableAttrs = (ins
29+
"mlir::IntegerAttr":$discardable_attr_key,
30+
"SimpleAAttr":$other_discardable_attr_key
31+
);
2832

2933
let extraClassDeclaration = [{
3034
void registerAttributes();

mlir/tools/mlir-tblgen/DialectGen.cpp

Lines changed: 92 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,21 @@ using DialectFilterIterator =
4343
std::function<bool(const llvm::Record *)>>;
4444
} // namespace
4545

46+
static void populateDiscardableAttributes(
47+
Dialect &dialect, llvm::DagInit *discardableAttrDag,
48+
SmallVector<std::pair<std::string, std::string>> &discardableAttributes) {
49+
for (int i : llvm::seq<int>(0, discardableAttrDag->getNumArgs())) {
50+
llvm::Init *arg = discardableAttrDag->getArg(i);
51+
52+
StringRef givenName = discardableAttrDag->getArgNameStr(i);
53+
if (givenName.empty())
54+
PrintFatalError(dialect.getDef()->getLoc(),
55+
"discardable attributes must be named");
56+
discardableAttributes.push_back(
57+
{givenName.str(), arg->getAsUnquotedString()});
58+
}
59+
}
60+
4661
/// Given a set of records for a T, filter the ones that correspond to
4762
/// the given dialect.
4863
template <typename T>
@@ -180,6 +195,44 @@ static const char *const operationInterfaceFallbackDecl = R"(
180195
mlir::OperationName opName) override;
181196
)";
182197

198+
/// The code block for the discardable attribute helper.
199+
static const char *const discardableAttrHelperDecl = R"(
200+
/// Helper to manage the discardable attribute `{1}`.
201+
class {0}AttrHelper {{
202+
::mlir::StringAttr name;
203+
public:
204+
static constexpr ::llvm::StringLiteral getNameStr() {{
205+
return "{4}.{1}";
206+
}
207+
constexpr ::mlir::StringAttr getName() {{
208+
return name;
209+
}
210+
211+
{0}AttrHelper(::mlir::MLIRContext *ctx)
212+
: name(::mlir::StringAttr::get(ctx, getNameStr())) {{}
213+
214+
{2} getAttr(::mlir::Operation *op) {{
215+
return op->getAttrOfType<{2}>(name);
216+
}
217+
void setAttr(::mlir::Operation *op, {2} val) {{
218+
op->setAttr(name, val);
219+
}
220+
bool isAttrPresent(::mlir::Operation *op) {{
221+
return op->hasAttrOfType<{2}>(name);
222+
}
223+
void removeAttr(::mlir::Operation *op) {{
224+
assert(op->hasAttrOfType<{2}>(name));
225+
op->removeAttr(name);
226+
}
227+
};
228+
{0}AttrHelper get{0}AttrHelper() {
229+
return {3}AttrName;
230+
}
231+
private:
232+
{0}AttrHelper {3}AttrName;
233+
public:
234+
)";
235+
183236
/// Generate the declaration for the given dialect class.
184237
static void emitDialectDecl(Dialect &dialect, raw_ostream &os) {
185238
// Emit all nested namespaces.
@@ -215,6 +268,22 @@ static void emitDialectDecl(Dialect &dialect, raw_ostream &os) {
215268
os << regionResultAttrVerifierDecl;
216269
if (dialect.hasOperationInterfaceFallback())
217270
os << operationInterfaceFallbackDecl;
271+
272+
llvm::DagInit *discardableAttrDag = dialect.getDiscardableAttributes();
273+
SmallVector<std::pair<std::string, std::string>> discardableAttributes;
274+
populateDiscardableAttributes(dialect, discardableAttrDag,
275+
discardableAttributes);
276+
277+
for (const auto &attrPair : discardableAttributes) {
278+
std::string camelNameUpper = llvm::convertToCamelFromSnakeCase(
279+
attrPair.first, /*capitalizeFirst=*/true);
280+
std::string camelName = llvm::convertToCamelFromSnakeCase(
281+
attrPair.first, /*capitalizeFirst=*/false);
282+
os << llvm::formatv(discardableAttrHelperDecl, camelNameUpper,
283+
attrPair.first, attrPair.second, camelName,
284+
dialect.getName());
285+
}
286+
218287
if (std::optional<StringRef> extraDecl = dialect.getExtraClassDeclaration())
219288
os << *extraDecl;
220289

@@ -252,9 +321,12 @@ static bool emitDialectDecls(const llvm::RecordKeeper &recordKeeper,
252321
/// {1}: Initialization code that is emitted in the ctor body before calling
253322
/// initialize(), such as dependent dialect registration.
254323
/// {2}: The dialect parent class.
324+
/// {3}: Extra members to initialize
255325
static const char *const dialectConstructorStr = R"(
256326
{0}::{0}(::mlir::MLIRContext *context)
257-
: ::mlir::{2}(getDialectNamespace(), context, ::mlir::TypeID::get<{0}>()) {{
327+
: ::mlir::{2}(getDialectNamespace(), context, ::mlir::TypeID::get<{0}>())
328+
{3}
329+
{{
258330
{1}
259331
initialize();
260332
}
@@ -268,7 +340,9 @@ static const char *const dialectDestructorStr = R"(
268340
269341
)";
270342

271-
static void emitDialectDef(Dialect &dialect, raw_ostream &os) {
343+
static void emitDialectDef(Dialect &dialect,
344+
const llvm::RecordKeeper &recordKeeper,
345+
raw_ostream &os) {
272346
std::string cppClassName = dialect.getCppClassName();
273347

274348
// Emit the TypeID explicit specializations to have a single symbol def.
@@ -295,8 +369,22 @@ static void emitDialectDef(Dialect &dialect, raw_ostream &os) {
295369
// Emit the constructor and destructor.
296370
StringRef superClassName =
297371
dialect.isExtensible() ? "ExtensibleDialect" : "Dialect";
372+
373+
llvm::DagInit *discardableAttrDag = dialect.getDiscardableAttributes();
374+
SmallVector<std::pair<std::string, std::string>> discardableAttributes;
375+
populateDiscardableAttributes(dialect, discardableAttrDag,
376+
discardableAttributes);
377+
std::string discardableAttributesInit;
378+
for (const auto &attrPair : discardableAttributes) {
379+
std::string camelName = llvm::convertToCamelFromSnakeCase(
380+
attrPair.first, /*capitalizeFirst=*/false);
381+
llvm::raw_string_ostream os(discardableAttributesInit);
382+
os << ", " << camelName << "AttrName(context)";
383+
}
384+
298385
os << llvm::formatv(dialectConstructorStr, cppClassName,
299-
dependentDialectRegistrations, superClassName);
386+
dependentDialectRegistrations, superClassName,
387+
discardableAttributesInit);
300388
if (!dialect.hasNonDefaultDestructor())
301389
os << llvm::formatv(dialectDestructorStr, cppClassName);
302390
}
@@ -313,7 +401,7 @@ static bool emitDialectDefs(const llvm::RecordKeeper &recordKeeper,
313401
std::optional<Dialect> dialect = findDialectToGenerate(dialects);
314402
if (!dialect)
315403
return true;
316-
emitDialectDef(*dialect, os);
404+
emitDialectDef(*dialect, recordKeeper, os);
317405
return false;
318406
}
319407

0 commit comments

Comments
 (0)