Skip to content

Commit

Permalink
[AsmParser] Rework logic around "region argument parsing"
Browse files Browse the repository at this point in the history
The asm parser had a notional distinction between parsing an
operand (like "%foo" or "%4#3") and parsing a region argument
(which isn't supposed to allow a result number like #3).

Unfortunately the implementation has two problems:

1) It didn't actually check for the result number and reject
   it.  parseRegionArgument and parseOperand were identical.
2) It had a lot of machinery built up around it that paralleled
   operand parsing.  This also was functionally identical, but
   also had some subtle differences (e.g. the parseOptional
   stuff had a different result type).

I thought about just removing all of this, but decided that the
missing error checking was important, so I reimplemented it with
a `allowResultNumber` flag on parseOperand.  This keeps the
codepaths unified and adds the missing error checks.

Differential Revision: https://reviews.llvm.org/D124470
  • Loading branch information
lattner committed Apr 28, 2022
1 parent 6c81b57 commit 5dedf91
Show file tree
Hide file tree
Showing 15 changed files with 93 additions and 107 deletions.
11 changes: 7 additions & 4 deletions flang/lib/Optimizer/Dialect/FIROps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1563,7 +1563,8 @@ mlir::ParseResult fir::IterWhileOp::parse(mlir::OpAsmParser &parser,
mlir::OperationState &result) {
auto &builder = parser.getBuilder();
mlir::OpAsmParser::UnresolvedOperand inductionVariable, lb, ub, step;
if (parser.parseLParen() || parser.parseRegionArgument(inductionVariable) ||
if (parser.parseLParen() ||
parser.parseOperand(inductionVariable, /*allowResultNumber=*/false) ||
parser.parseEqual())
return mlir::failure();

Expand All @@ -1581,8 +1582,9 @@ mlir::ParseResult fir::IterWhileOp::parse(mlir::OpAsmParser &parser,

mlir::OpAsmParser::UnresolvedOperand iterateVar, iterateInput;
if (parser.parseKeyword("and") || parser.parseLParen() ||
parser.parseRegionArgument(iterateVar) || parser.parseEqual() ||
parser.parseOperand(iterateInput) || parser.parseRParen() ||
parser.parseOperand(iterateVar, /*allowResultNumber=*/false) ||
parser.parseEqual() || parser.parseOperand(iterateInput) ||
parser.parseRParen() ||
parser.resolveOperand(iterateInput, i1Type, result.operands))
return mlir::failure();

Expand Down Expand Up @@ -1876,7 +1878,8 @@ mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser &parser,
auto &builder = parser.getBuilder();
mlir::OpAsmParser::UnresolvedOperand inductionVariable, lb, ub, step;
// Parse the induction variable followed by '='.
if (parser.parseRegionArgument(inductionVariable) || parser.parseEqual())
if (parser.parseOperand(inductionVariable, /*allowResultNumber=*/false) ||
parser.parseEqual())
return mlir::failure();

// Parse loop bounds.
Expand Down
48 changes: 15 additions & 33 deletions mlir/include/mlir/IR/OpImplementation.h
Original file line number Diff line number Diff line change
Expand Up @@ -584,7 +584,7 @@ class AsmParser {
}

/// These are the supported delimiters around operand lists and region
/// argument lists, used by parseOperandList and parseRegionArgumentList.
/// argument lists, used by parseOperandList.
enum class Delimiter {
/// Zero or more operands with no delimiters.
None,
Expand Down Expand Up @@ -1110,22 +1110,27 @@ class OpAsmParser : public AsmParser {
Optional<ArrayRef<NamedAttribute>> parsedAttributes = llvm::None,
Optional<FunctionType> parsedFnType = llvm::None) = 0;

/// Parse a single operand.
virtual ParseResult parseOperand(UnresolvedOperand &result) = 0;
/// Parse a single SSA value operand name along with a result number if
/// `allowResultNumber` is true.
virtual ParseResult parseOperand(UnresolvedOperand &result,
bool allowResultNumber = true) = 0;

/// Parse a single operand if present.
virtual OptionalParseResult
parseOptionalOperand(UnresolvedOperand &result) = 0;
parseOptionalOperand(UnresolvedOperand &result,
bool allowResultNumber = true) = 0;

/// Parse zero or more SSA comma-separated operand references with a specified
/// surrounding delimiter, and an optional required operand count.
virtual ParseResult
parseOperandList(SmallVectorImpl<UnresolvedOperand> &result,
int requiredOperandCount = -1,
Delimiter delimiter = Delimiter::None) = 0;
virtual ParseResult parseOperandList(
SmallVectorImpl<UnresolvedOperand> &result, int requiredOperandCount = -1,
Delimiter delimiter = Delimiter::None, bool allowResultNumber = true) = 0;

ParseResult parseOperandList(SmallVectorImpl<UnresolvedOperand> &result,
Delimiter delimiter) {
return parseOperandList(result, /*requiredOperandCount=*/-1, delimiter);
Delimiter delimiter,
bool allowResultNumber = true) {
return parseOperandList(result, /*requiredOperandCount=*/-1, delimiter,
allowResultNumber);
}

/// Parse zero or more trailing SSA comma-separated trailing operand
Expand Down Expand Up @@ -1243,29 +1248,6 @@ class OpAsmParser : public AsmParser {
ArrayRef<Type> argTypes = {},
bool enableNameShadowing = false) = 0;

/// Parse a region argument, this argument is resolved when calling
/// 'parseRegion'.
virtual ParseResult parseRegionArgument(UnresolvedOperand &argument) = 0;

/// Parse zero or more region arguments with a specified surrounding
/// delimiter, and an optional required argument count. Region arguments
/// define new values; so this also checks if values with the same names have
/// not been defined yet.
virtual ParseResult
parseRegionArgumentList(SmallVectorImpl<UnresolvedOperand> &result,
int requiredOperandCount = -1,
Delimiter delimiter = Delimiter::None) = 0;
virtual ParseResult
parseRegionArgumentList(SmallVectorImpl<UnresolvedOperand> &result,
Delimiter delimiter) {
return parseRegionArgumentList(result, /*requiredOperandCount=*/-1,
delimiter);
}

/// Parse a region argument if present.
virtual ParseResult
parseOptionalRegionArgument(UnresolvedOperand &argument) = 0;

//===--------------------------------------------------------------------===//
// Successor Parsing
//===--------------------------------------------------------------------===//
Expand Down
7 changes: 4 additions & 3 deletions mlir/lib/Dialect/Affine/IR/AffineOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1433,7 +1433,8 @@ ParseResult AffineForOp::parse(OpAsmParser &parser, OperationState &result) {
auto &builder = parser.getBuilder();
OpAsmParser::UnresolvedOperand inductionVariable;
// Parse the induction variable followed by '='.
if (parser.parseRegionArgument(inductionVariable) || parser.parseEqual())
if (parser.parseOperand(inductionVariable, /*allowResultNumber=*/false) ||
parser.parseEqual())
return failure();

// Parse loop bounds.
Expand Down Expand Up @@ -3527,8 +3528,8 @@ ParseResult AffineParallelOp::parse(OpAsmParser &parser,
auto &builder = parser.getBuilder();
auto indexType = builder.getIndexType();
SmallVector<OpAsmParser::UnresolvedOperand, 4> ivs;
if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1,
OpAsmParser::Delimiter::Paren) ||
if (parser.parseOperandList(ivs, OpAsmParser::Delimiter::Paren,
/*allowResultNumber=*/false) ||
parser.parseEqual() ||
parseAffineMapWithMinMax(parser, result, MinMaxKind::Max) ||
parser.parseKeyword("to") ||
Expand Down
6 changes: 4 additions & 2 deletions mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,8 @@ static ParseResult parseSwitchOpCases(
parser.parseSuccessor(defaultDestination))
return failure();
if (succeeded(parser.parseOptionalLParen())) {
if (parser.parseRegionArgumentList(defaultOperands) ||
if (parser.parseOperandList(defaultOperands, OpAsmParser::Delimiter::None,
/*allowResultNumber=*/false) ||
parser.parseColonTypeList(defaultOperandTypes) || parser.parseRParen())
return failure();
}
Expand All @@ -509,7 +510,8 @@ static ParseResult parseSwitchOpCases(
failed(parser.parseSuccessor(destination)))
return failure();
if (succeeded(parser.parseOptionalLParen())) {
if (failed(parser.parseRegionArgumentList(operands)) ||
if (failed(parser.parseOperandList(operands, OpAsmParser::Delimiter::None,
/*allowResultNumber=*/false)) ||
failed(parser.parseColonTypeList(operandTypes)) ||
failed(parser.parseRParen()))
return failure();
Expand Down
11 changes: 6 additions & 5 deletions mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -539,17 +539,17 @@ parseSizeAssignment(OpAsmParser &parser,
MutableArrayRef<OpAsmParser::UnresolvedOperand> indices) {
assert(indices.size() == 3 && "space for three indices expected");
SmallVector<OpAsmParser::UnresolvedOperand, 3> args;
if (parser.parseRegionArgumentList(args, /*requiredOperandCount=*/3,
OpAsmParser::Delimiter::Paren) ||
if (parser.parseOperandList(args, OpAsmParser::Delimiter::Paren,
/*allowResultNumber=*/false) ||
parser.parseKeyword("in") || parser.parseLParen())
return failure();
std::move(args.begin(), args.end(), indices.begin());

for (int i = 0; i < 3; ++i) {
if (i != 0 && parser.parseComma())
return failure();
if (parser.parseRegionArgument(regionSizes[i]) || parser.parseEqual() ||
parser.parseOperand(sizes[i]))
if (parser.parseOperand(regionSizes[i], /*allowResultNumber=*/false) ||
parser.parseEqual() || parser.parseOperand(sizes[i]))
return failure();
}

Expand Down Expand Up @@ -869,7 +869,8 @@ parseAttributions(OpAsmParser &parser, StringRef keyword,
OpAsmParser::UnresolvedOperand arg;
Type type;

if (parser.parseRegionArgument(arg) || parser.parseColonType(type))
if (parser.parseOperand(arg, /*allowResultNumber=*/false) ||
parser.parseColonType(type))
return failure();

args.push_back(arg);
Expand Down
3 changes: 2 additions & 1 deletion mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,8 @@ static ParseResult parseSwitchOpCases(
if (parser.parseColon() || parser.parseSuccessor(destination))
return failure();
if (!parser.parseOptionalLParen()) {
if (parser.parseRegionArgumentList(operands) ||
if (parser.parseOperandList(operands, OpAsmParser::Delimiter::None,
/*allowResultNumber=*/false) ||
parser.parseColonTypeList(operandTypes) || parser.parseRParen())
return failure();
}
Expand Down
3 changes: 2 additions & 1 deletion mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ parseOperandList(OpAsmParser &parser, StringRef keyword,
OpAsmParser::UnresolvedOperand arg;
Type type;

if (parser.parseRegionArgument(arg) || parser.parseColonType(type))
if (parser.parseOperand(arg, /*allowResultNumber=*/false) ||
parser.parseColonType(type))
return failure();

args.push_back(arg);
Expand Down
8 changes: 4 additions & 4 deletions mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -524,8 +524,8 @@ parseWsLoopControl(OpAsmParser &parser, Region &region,
SmallVectorImpl<Type> &loopVarTypes, UnitAttr &inclusive) {
// Parse an opening `(` followed by induction variables followed by `)`
SmallVector<OpAsmParser::UnresolvedOperand> ivs;
if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1,
OpAsmParser::Delimiter::Paren))
if (parser.parseOperandList(ivs, OpAsmParser::Delimiter::Paren,
/*allowResultNumber=*/false))
return failure();

size_t numIVs = ivs.size();
Expand Down Expand Up @@ -587,8 +587,8 @@ void printWsLoopControl(OpAsmPrinter &p, Operation *op, Region &region,
ParseResult SimdLoopOp::parse(OpAsmParser &parser, OperationState &result) {
// Parse an opening `(` followed by induction variables followed by `)`
SmallVector<OpAsmParser::UnresolvedOperand> ivs;
if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1,
OpAsmParser::Delimiter::Paren))
if (parser.parseOperandList(ivs, OpAsmParser::Delimiter::Paren,
/*allowResultNumber=*/false))
return failure();
int numIVs = static_cast<int>(ivs.size());
Type loopVarType;
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ ParseResult ForEachOp::parse(OpAsmParser &parser, OperationState &result) {
// Parse the loop variable followed by type.
OpAsmParser::UnresolvedOperand loopVariable;
Type loopVariableType;
if (parser.parseRegionArgument(loopVariable) ||
if (parser.parseOperand(loopVariable, /*allowResultNumber=*/false) ||
parser.parseColonType(loopVariableType))
return failure();

Expand Down
7 changes: 4 additions & 3 deletions mlir/lib/Dialect/SCF/SCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,8 @@ ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
auto &builder = parser.getBuilder();
OpAsmParser::UnresolvedOperand inductionVariable, lb, ub, step;
// Parse the induction variable followed by '='.
if (parser.parseRegionArgument(inductionVariable) || parser.parseEqual())
if (parser.parseOperand(inductionVariable, /*allowResultNumber=*/false) ||
parser.parseEqual())
return failure();

// Parse loop bounds.
Expand Down Expand Up @@ -1975,8 +1976,8 @@ ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) {
auto &builder = parser.getBuilder();
// Parse an opening `(` followed by induction variables followed by `)`
SmallVector<OpAsmParser::UnresolvedOperand, 4> ivs;
if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1,
OpAsmParser::Delimiter::Paren))
if (parser.parseOperandList(ivs, OpAsmParser::Delimiter::Paren,
/*allowResultNumber=*/false))
return failure();

// Parse loop bounds.
Expand Down
3 changes: 2 additions & 1 deletion mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4698,7 +4698,8 @@ ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser,
OpAsmParser::UnresolvedOperand laneId;

// Parse predicate operand.
if (parser.parseLParen() || parser.parseRegionArgument(laneId) ||
if (parser.parseLParen() ||
parser.parseOperand(laneId, /*allowResultNumber=*/false) ||
parser.parseRParen())
return failure();

Expand Down
8 changes: 6 additions & 2 deletions mlir/lib/IR/FunctionImplementation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,12 @@ ParseResult mlir::function_interface_impl::parseFunctionArgumentList(
// Parse argument name if present.
OpAsmParser::UnresolvedOperand argument;
Type argumentType;
if (succeeded(parser.parseOptionalRegionArgument(argument)) &&
!argument.name.empty()) {
auto hadSSAValue = parser.parseOptionalOperand(argument,
/*allowResultNumber=*/false);
if (hadSSAValue.hasValue()) {
if (failed(hadSSAValue.getValue()))
return failure(); // Argument was present but malformed.

// Reject this if the preceding argument was missing a name.
if (argNames.empty() && !argTypes.empty())
return parser.emitError(loc, "expected type instead of SSA identifier");
Expand Down
Loading

0 comments on commit 5dedf91

Please sign in to comment.