Skip to content

Commit 0eed305

Browse files
authored
Revert "[MLIR][TableGen] Use const pointers for various Init objects" (llvm#112506)
Reverts llvm#112316 Bots are failing.
1 parent 7c5d5c0 commit 0eed305

File tree

14 files changed

+68
-80
lines changed

14 files changed

+68
-80
lines changed

mlir/include/mlir/TableGen/AttrOrTypeDef.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ class AttrOrTypeParameter {
105105
std::optional<StringRef> getDefaultValue() const;
106106

107107
/// Return the underlying def of this parameter.
108-
const llvm::Init *getDef() const;
108+
llvm::Init *getDef() const;
109109

110110
/// The parameter is pointer-comparable.
111111
bool operator==(const AttrOrTypeParameter &other) const {

mlir/include/mlir/TableGen/Dialect.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ class Dialect {
9292
/// dialect.
9393
bool usePropertiesForAttributes() const;
9494

95-
const llvm::DagInit *getDiscardableAttributes() const;
95+
llvm::DagInit *getDiscardableAttributes() const;
9696

9797
const llvm::Record *getDef() const { return def; }
9898

mlir/include/mlir/TableGen/Operator.h

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -119,15 +119,14 @@ class Operator {
119119

120120
/// A utility iterator over a list of variable decorators.
121121
struct VariableDecoratorIterator
122-
: public llvm::mapped_iterator<const llvm::Init *const *,
123-
VariableDecorator (*)(
124-
const llvm::Init *)> {
122+
: public llvm::mapped_iterator<llvm::Init *const *,
123+
VariableDecorator (*)(llvm::Init *)> {
125124
/// Initializes the iterator to the specified iterator.
126-
VariableDecoratorIterator(const llvm::Init *const *it)
127-
: llvm::mapped_iterator<const llvm::Init *const *,
128-
VariableDecorator (*)(const llvm::Init *)>(
129-
it, &unwrap) {}
130-
static VariableDecorator unwrap(const llvm::Init *init);
125+
VariableDecoratorIterator(llvm::Init *const *it)
126+
: llvm::mapped_iterator<llvm::Init *const *,
127+
VariableDecorator (*)(llvm::Init *)>(it,
128+
&unwrap) {}
129+
static VariableDecorator unwrap(llvm::Init *init);
131130
};
132131
using var_decorator_iterator = VariableDecoratorIterator;
133132
using var_decorator_range = llvm::iterator_range<VariableDecoratorIterator>;

mlir/lib/TableGen/AttrOrTypeDef.cpp

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ AttrOrTypeDef::AttrOrTypeDef(const llvm::Record *def) : def(def) {
4040
auto *builderList =
4141
dyn_cast_or_null<llvm::ListInit>(def->getValueInit("builders"));
4242
if (builderList && !builderList->empty()) {
43-
for (const llvm::Init *init : builderList->getValues()) {
43+
for (llvm::Init *init : builderList->getValues()) {
4444
AttrOrTypeBuilder builder(cast<llvm::DefInit>(init)->getDef(),
4545
def->getLoc());
4646

@@ -58,8 +58,8 @@ AttrOrTypeDef::AttrOrTypeDef(const llvm::Record *def) : def(def) {
5858
if (auto *traitList = def->getValueAsListInit("traits")) {
5959
SmallPtrSet<const llvm::Init *, 32> traitSet;
6060
traits.reserve(traitSet.size());
61-
llvm::unique_function<void(const llvm::ListInit *)> processTraitList =
62-
[&](const llvm::ListInit *traitList) {
61+
llvm::unique_function<void(llvm::ListInit *)> processTraitList =
62+
[&](llvm::ListInit *traitList) {
6363
for (auto *traitInit : *traitList) {
6464
if (!traitSet.insert(traitInit).second)
6565
continue;
@@ -335,9 +335,7 @@ std::optional<StringRef> AttrOrTypeParameter::getDefaultValue() const {
335335
return result && !result->empty() ? result : std::nullopt;
336336
}
337337

338-
const llvm::Init *AttrOrTypeParameter::getDef() const {
339-
return def->getArg(index);
340-
}
338+
llvm::Init *AttrOrTypeParameter::getDef() const { return def->getArg(index); }
341339

342340
std::optional<Constraint> AttrOrTypeParameter::getConstraint() const {
343341
if (auto *param = dyn_cast<llvm::DefInit>(getDef()))
@@ -351,7 +349,7 @@ std::optional<Constraint> AttrOrTypeParameter::getConstraint() const {
351349
//===----------------------------------------------------------------------===//
352350

353351
bool AttributeSelfTypeParameter::classof(const AttrOrTypeParameter *param) {
354-
const llvm::Init *paramDef = param->getDef();
352+
llvm::Init *paramDef = param->getDef();
355353
if (auto *paramDefInit = dyn_cast<llvm::DefInit>(paramDef))
356354
return paramDefInit->getDef()->isSubClassOf("AttributeSelfTypeParameter");
357355
return false;

mlir/lib/TableGen/Attribute.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ StringRef Attribute::getDerivedCodeBody() const {
126126
Dialect Attribute::getDialect() const {
127127
const llvm::RecordVal *record = def->getValue("dialect");
128128
if (record && record->getValue()) {
129-
if (const DefInit *init = dyn_cast<DefInit>(record->getValue()))
129+
if (DefInit *init = dyn_cast<DefInit>(record->getValue()))
130130
return Dialect(init->getDef());
131131
}
132132
return Dialect(nullptr);

mlir/lib/TableGen/Dialect.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ bool Dialect::usePropertiesForAttributes() const {
106106
return def->getValueAsBit("usePropertiesForAttributes");
107107
}
108108

109-
const llvm::DagInit *Dialect::getDiscardableAttributes() const {
109+
llvm::DagInit *Dialect::getDiscardableAttributes() const {
110110
return def->getValueAsDag("discardableAttrs");
111111
}
112112

mlir/lib/TableGen/Interfaces.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ using namespace mlir::tblgen;
2222
//===----------------------------------------------------------------------===//
2323

2424
InterfaceMethod::InterfaceMethod(const llvm::Record *def) : def(def) {
25-
const llvm::DagInit *args = def->getValueAsDag("arguments");
25+
llvm::DagInit *args = def->getValueAsDag("arguments");
2626
for (unsigned i = 0, e = args->getNumArgs(); i != e; ++i) {
2727
arguments.push_back(
2828
{llvm::cast<llvm::StringInit>(args->getArg(i))->getValue(),
@@ -78,7 +78,7 @@ Interface::Interface(const llvm::Record *def) : def(def) {
7878

7979
// Initialize the interface methods.
8080
auto *listInit = dyn_cast<llvm::ListInit>(def->getValueInit("methods"));
81-
for (const llvm::Init *init : listInit->getValues())
81+
for (llvm::Init *init : listInit->getValues())
8282
methods.emplace_back(cast<llvm::DefInit>(init)->getDef());
8383

8484
// Initialize the interface base classes.
@@ -98,7 +98,7 @@ Interface::Interface(const llvm::Record *def) : def(def) {
9898
baseInterfaces.push_back(std::make_unique<Interface>(baseInterface));
9999
basesAdded.insert(baseInterface.getName());
100100
};
101-
for (const llvm::Init *init : basesInit->getValues())
101+
for (llvm::Init *init : basesInit->getValues())
102102
addBaseInterfaceFn(Interface(cast<llvm::DefInit>(init)->getDef()));
103103
}
104104

mlir/lib/TableGen/Operator.cpp

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ std::string Operator::getQualCppClassName() const {
161161
StringRef Operator::getCppNamespace() const { return cppNamespace; }
162162

163163
int Operator::getNumResults() const {
164-
const DagInit *results = def.getValueAsDag("results");
164+
DagInit *results = def.getValueAsDag("results");
165165
return results->getNumArgs();
166166
}
167167

@@ -198,12 +198,12 @@ auto Operator::getResults() const -> const_value_range {
198198
}
199199

200200
TypeConstraint Operator::getResultTypeConstraint(int index) const {
201-
const DagInit *results = def.getValueAsDag("results");
201+
DagInit *results = def.getValueAsDag("results");
202202
return TypeConstraint(cast<DefInit>(results->getArg(index)));
203203
}
204204

205205
StringRef Operator::getResultName(int index) const {
206-
const DagInit *results = def.getValueAsDag("results");
206+
DagInit *results = def.getValueAsDag("results");
207207
return results->getArgNameStr(index);
208208
}
209209

@@ -241,7 +241,7 @@ Operator::arg_range Operator::getArgs() const {
241241
}
242242

243243
StringRef Operator::getArgName(int index) const {
244-
const DagInit *argumentValues = def.getValueAsDag("arguments");
244+
DagInit *argumentValues = def.getValueAsDag("arguments");
245245
return argumentValues->getArgNameStr(index);
246246
}
247247

@@ -557,7 +557,7 @@ void Operator::populateOpStructure() {
557557
auto *opVarClass = recordKeeper.getClass("OpVariable");
558558
numNativeAttributes = 0;
559559

560-
const DagInit *argumentValues = def.getValueAsDag("arguments");
560+
DagInit *argumentValues = def.getValueAsDag("arguments");
561561
unsigned numArgs = argumentValues->getNumArgs();
562562

563563
// Mapping from name of to argument or result index. Arguments are indexed
@@ -721,8 +721,8 @@ void Operator::populateOpStructure() {
721721
" to precede it in traits list");
722722
};
723723

724-
std::function<void(const llvm::ListInit *)> insert;
725-
insert = [&](const llvm::ListInit *traitList) {
724+
std::function<void(llvm::ListInit *)> insert;
725+
insert = [&](llvm::ListInit *traitList) {
726726
for (auto *traitInit : *traitList) {
727727
auto *def = cast<DefInit>(traitInit)->getDef();
728728
if (def->isSubClassOf("TraitList")) {
@@ -780,7 +780,7 @@ void Operator::populateOpStructure() {
780780
auto *builderList =
781781
dyn_cast_or_null<llvm::ListInit>(def.getValueInit("builders"));
782782
if (builderList && !builderList->empty()) {
783-
for (const llvm::Init *init : builderList->getValues())
783+
for (llvm::Init *init : builderList->getValues())
784784
builders.emplace_back(cast<llvm::DefInit>(init)->getDef(), def.getLoc());
785785
} else if (skipDefaultBuilders()) {
786786
PrintFatalError(
@@ -818,8 +818,7 @@ bool Operator::hasAssemblyFormat() const {
818818
}
819819

820820
StringRef Operator::getAssemblyFormat() const {
821-
return TypeSwitch<const llvm::Init *, StringRef>(
822-
def.getValueInit("assemblyFormat"))
821+
return TypeSwitch<llvm::Init *, StringRef>(def.getValueInit("assemblyFormat"))
823822
.Case<llvm::StringInit>([&](auto *init) { return init->getValue(); });
824823
}
825824

@@ -833,7 +832,7 @@ void Operator::print(llvm::raw_ostream &os) const {
833832
}
834833
}
835834

836-
auto Operator::VariableDecoratorIterator::unwrap(const llvm::Init *init)
835+
auto Operator::VariableDecoratorIterator::unwrap(llvm::Init *init)
837836
-> VariableDecorator {
838837
return VariableDecorator(cast<llvm::DefInit>(init)->getDef());
839838
}

mlir/lib/TableGen/Pattern.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -700,7 +700,7 @@ int Pattern::getBenefit() const {
700700
// The initial benefit value is a heuristic with number of ops in the source
701701
// pattern.
702702
int initBenefit = getSourcePattern().getNumOps();
703-
const llvm::DagInit *delta = def.getValueAsDag("benefitDelta");
703+
llvm::DagInit *delta = def.getValueAsDag("benefitDelta");
704704
if (delta->getNumArgs() != 1 || !isa<llvm::IntInit>(delta->getArg(0))) {
705705
PrintFatalError(&def,
706706
"The 'addBenefit' takes and only takes one integer value");

mlir/lib/TableGen/Type.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ std::optional<StringRef> TypeConstraint::getBuilderCall() const {
5050
const llvm::RecordVal *builderCall = baseType->getValue("builderCall");
5151
if (!builderCall || !builderCall->getValue())
5252
return std::nullopt;
53-
return TypeSwitch<const llvm::Init *, std::optional<StringRef>>(
53+
return TypeSwitch<llvm::Init *, std::optional<StringRef>>(
5454
builderCall->getValue())
5555
.Case<llvm::StringInit>([&](auto *init) {
5656
StringRef value = init->getValue();

mlir/lib/Tools/mlir-tblgen/MlirTblgenMain.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ enum DeprecatedAction { None, Warn, Error };
3030
static DeprecatedAction actionOnDeprecatedValue;
3131

3232
// Returns if there is a use of `deprecatedInit` in `field`.
33-
static bool findUse(const Init *field, const Init *deprecatedInit,
34-
llvm::DenseMap<const Init *, bool> &known) {
33+
static bool findUse(Init *field, Init *deprecatedInit,
34+
llvm::DenseMap<Init *, bool> &known) {
3535
if (field == deprecatedInit)
3636
return true;
3737

@@ -64,13 +64,13 @@ static bool findUse(const Init *field, const Init *deprecatedInit,
6464
if (findUse(dagInit->getOperator(), deprecatedInit, known))
6565
return memoize(true);
6666

67-
return memoize(llvm::any_of(dagInit->getArgs(), [&](const Init *arg) {
67+
return memoize(llvm::any_of(dagInit->getArgs(), [&](Init *arg) {
6868
return findUse(arg, deprecatedInit, known);
6969
}));
7070
}
7171

72-
if (const ListInit *li = dyn_cast<ListInit>(field)) {
73-
return memoize(llvm::any_of(li->getValues(), [&](const Init *jt) {
72+
if (ListInit *li = dyn_cast<ListInit>(field)) {
73+
return memoize(llvm::any_of(li->getValues(), [&](Init *jt) {
7474
return findUse(jt, deprecatedInit, known);
7575
}));
7676
}
@@ -83,8 +83,8 @@ static bool findUse(const Init *field, const Init *deprecatedInit,
8383
}
8484

8585
// Returns if there is a use of `deprecatedInit` in `record`.
86-
static bool findUse(Record &record, const Init *deprecatedInit,
87-
llvm::DenseMap<const Init *, bool> &known) {
86+
static bool findUse(Record &record, Init *deprecatedInit,
87+
llvm::DenseMap<Init *, bool> &known) {
8888
return llvm::any_of(record.getValues(), [&](const RecordVal &val) {
8989
return findUse(val.getValue(), deprecatedInit, known);
9090
});
@@ -100,7 +100,7 @@ static void warnOfDeprecatedUses(const RecordKeeper &records) {
100100
if (!r || !r->getValue())
101101
continue;
102102

103-
llvm::DenseMap<const Init *, bool> hasUse;
103+
llvm::DenseMap<Init *, bool> hasUse;
104104
if (auto *si = dyn_cast<StringInit>(r->getValue())) {
105105
for (auto &jt : records.getDefs()) {
106106
// Skip anonymous defs.

mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,8 @@ class Generator {
4646
private:
4747
/// Emits parse calls to construct given kind.
4848
void emitParseHelper(StringRef kind, StringRef returnType, StringRef builder,
49-
ArrayRef<const Init *> args,
50-
ArrayRef<std::string> argNames, StringRef failure,
51-
mlir::raw_indented_ostream &ios);
49+
ArrayRef<Init *> args, ArrayRef<std::string> argNames,
50+
StringRef failure, mlir::raw_indented_ostream &ios);
5251

5352
/// Emits print instructions.
5453
void emitPrintHelper(const Record *memberRec, StringRef kind,
@@ -136,12 +135,10 @@ void Generator::emitParse(StringRef kind, const Record &x) {
136135
R"(static {0} read{1}(MLIRContext* context, DialectBytecodeReader &reader) )";
137136
mlir::raw_indented_ostream os(output);
138137
std::string returnType = getCType(&x);
139-
os << formatv(head,
140-
kind == "attribute" ? "::mlir::Attribute" : "::mlir::Type",
141-
x.getName());
142-
const DagInit *members = x.getValueAsDag("members");
143-
SmallVector<std::string> argNames = llvm::to_vector(
144-
map_range(members->getArgNames(), [](const StringInit *init) {
138+
os << formatv(head, kind == "attribute" ? "::mlir::Attribute" : "::mlir::Type", x.getName());
139+
DagInit *members = x.getValueAsDag("members");
140+
SmallVector<std::string> argNames =
141+
llvm::to_vector(map_range(members->getArgNames(), [](StringInit *init) {
145142
return init->getAsUnquotedString();
146143
}));
147144
StringRef builder = x.getValueAsString("cBuilder").trim();
@@ -151,7 +148,7 @@ void Generator::emitParse(StringRef kind, const Record &x) {
151148
}
152149

153150
void printParseConditional(mlir::raw_indented_ostream &ios,
154-
ArrayRef<const Init *> args,
151+
ArrayRef<Init *> args,
155152
ArrayRef<std::string> argNames) {
156153
ios << "if ";
157154
auto parenScope = ios.scope("(", ") {");
@@ -162,7 +159,7 @@ void printParseConditional(mlir::raw_indented_ostream &ios,
162159
};
163160

164161
auto parsedArgs =
165-
llvm::to_vector(make_filter_range(args, [](const Init *const attr) {
162+
llvm::to_vector(make_filter_range(args, [](Init *const attr) {
166163
const Record *def = cast<DefInit>(attr)->getDef();
167164
if (def->isSubClassOf("Array"))
168165
return true;
@@ -171,7 +168,7 @@ void printParseConditional(mlir::raw_indented_ostream &ios,
171168

172169
interleave(
173170
zip(parsedArgs, argNames),
174-
[&](std::tuple<const Init *&, const std::string &> it) {
171+
[&](std::tuple<llvm::Init *&, const std::string &> it) {
175172
const Record *attr = cast<DefInit>(std::get<0>(it))->getDef();
176173
std::string parser;
177174
if (auto optParser = attr->getValueAsOptionalString("cParser")) {
@@ -199,7 +196,7 @@ void printParseConditional(mlir::raw_indented_ostream &ios,
199196
}
200197

201198
void Generator::emitParseHelper(StringRef kind, StringRef returnType,
202-
StringRef builder, ArrayRef<const Init *> args,
199+
StringRef builder, ArrayRef<Init *> args,
203200
ArrayRef<std::string> argNames,
204201
StringRef failure,
205202
mlir::raw_indented_ostream &ios) {
@@ -213,7 +210,7 @@ void Generator::emitParseHelper(StringRef kind, StringRef returnType,
213210
// Print decls.
214211
std::string lastCType = "";
215212
for (auto [arg, name] : zip(args, argNames)) {
216-
const DefInit *first = dyn_cast<DefInit>(arg);
213+
DefInit *first = dyn_cast<DefInit>(arg);
217214
if (!first)
218215
PrintFatalError("Unexpected type for " + name);
219216
const Record *def = first->getDef();
@@ -254,14 +251,13 @@ void Generator::emitParseHelper(StringRef kind, StringRef returnType,
254251
std::string returnType = getCType(def);
255252
ios << "auto " << listHelperName(name) << " = [&]() -> FailureOr<"
256253
<< returnType << "> ";
257-
SmallVector<const Init *> args;
254+
SmallVector<Init *> args;
258255
SmallVector<std::string> argNames;
259256
if (def->isSubClassOf("CompositeBytecode")) {
260-
const DagInit *members = def->getValueAsDag("members");
261-
args = llvm::to_vector(map_range(
262-
members->getArgs(), [](Init *init) { return (const Init *)init; }));
257+
DagInit *members = def->getValueAsDag("members");
258+
args = llvm::to_vector(members->getArgs());
263259
argNames = llvm::to_vector(
264-
map_range(members->getArgNames(), [](const StringInit *init) {
260+
map_range(members->getArgNames(), [](StringInit *init) {
265261
return init->getAsUnquotedString();
266262
}));
267263
} else {
@@ -336,7 +332,7 @@ void Generator::emitPrint(StringRef kind, StringRef type,
336332
auto *members = rec->getValueAsDag("members");
337333
for (auto [arg, name] :
338334
llvm::zip(members->getArgs(), members->getArgNames())) {
339-
const DefInit *def = dyn_cast<DefInit>(arg);
335+
DefInit *def = dyn_cast<DefInit>(arg);
340336
assert(def);
341337
const Record *memberRec = def->getDef();
342338
emitPrintHelper(memberRec, kind, kind, name->getAsUnquotedString(), os);
@@ -389,7 +385,7 @@ void Generator::emitPrintHelper(const Record *memberRec, StringRef kind,
389385
auto *members = memberRec->getValueAsDag("members");
390386
for (auto [arg, argName] :
391387
zip(members->getArgs(), members->getArgNames())) {
392-
const DefInit *def = dyn_cast<DefInit>(arg);
388+
DefInit *def = dyn_cast<DefInit>(arg);
393389
assert(def);
394390
emitPrintHelper(def->getDef(), kind, parent,
395391
argName->getAsUnquotedString(), ios);

0 commit comments

Comments
 (0)