Skip to content

Commit dec8055

Browse files
[mlir] Use StringRef::operator== instead of StringRef::equals (NFC) (llvm#91560)
I'm planning to remove StringRef::equals in favor of StringRef::operator==. - StringRef::operator==/!= outnumber StringRef::equals by a factor of 10 under mlir/ in terms of their usage. - The elimination of StringRef::equals brings StringRef closer to std::string_view, which has operator== but not equals. - S == "foo" is more readable than S.equals("foo"), especially for !Long.Expression.equals("str") vs Long.Expression != "str".
1 parent fd1bd53 commit dec8055

File tree

14 files changed

+39
-47
lines changed

14 files changed

+39
-47
lines changed

mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -926,7 +926,7 @@ LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite(
926926
static bool isDefinedByCallTo(Value value, StringRef functionName) {
927927
assert(isa<LLVM::LLVMPointerType>(value.getType()));
928928
if (auto defOp = value.getDefiningOp<LLVM::CallOp>())
929-
return defOp.getCallee()->equals(functionName);
929+
return *defOp.getCallee() == functionName;
930930
return false;
931931
}
932932

mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp

+8-8
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,11 @@ static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
4242
static constexpr StringRef kInvalidCaseStr = "Unsupported WMMA variant.";
4343

4444
static NVVM::MMAFrag convertOperand(StringRef operandName) {
45-
if (operandName.equals("AOp"))
45+
if (operandName == "AOp")
4646
return NVVM::MMAFrag::a;
47-
if (operandName.equals("BOp"))
47+
if (operandName == "BOp")
4848
return NVVM::MMAFrag::b;
49-
if (operandName.equals("COp"))
49+
if (operandName == "COp")
5050
return NVVM::MMAFrag::c;
5151
llvm_unreachable("Unknown operand name");
5252
}
@@ -55,8 +55,8 @@ static NVVM::MMATypes getElementType(gpu::MMAMatrixType type) {
5555
if (type.getElementType().isF16())
5656
return NVVM::MMATypes::f16;
5757
if (type.getElementType().isF32())
58-
return type.getOperand().equals("COp") ? NVVM::MMATypes::f32
59-
: NVVM::MMATypes::tf32;
58+
return type.getOperand() == "COp" ? NVVM::MMATypes::f32
59+
: NVVM::MMATypes::tf32;
6060

6161
if (type.getElementType().isSignedInteger(8))
6262
return NVVM::MMATypes::s8;
@@ -99,15 +99,15 @@ struct WmmaLoadOpToNVVMLowering
9999
NVVM::MMATypes eltype = getElementType(retType);
100100
// NVVM intrinsics require to give mxnxk dimensions, infer the missing
101101
// dimension based on the valid intrinsics available.
102-
if (retType.getOperand().equals("AOp")) {
102+
if (retType.getOperand() == "AOp") {
103103
m = retTypeShape[0];
104104
k = retTypeShape[1];
105105
n = NVVM::WMMALoadOp::inferNDimension(m, k, eltype);
106-
} else if (retType.getOperand().equals("BOp")) {
106+
} else if (retType.getOperand() == "BOp") {
107107
k = retTypeShape[0];
108108
n = retTypeShape[1];
109109
m = NVVM::WMMALoadOp::inferMDimension(k, n, eltype);
110-
} else if (retType.getOperand().equals("COp")) {
110+
} else if (retType.getOperand() == "COp") {
111111
m = retTypeShape[0];
112112
n = retTypeShape[1];
113113
k = NVVM::WMMALoadOp::inferKDimension(m, n, eltype);

mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ static void maybeApplyPassLabel(OpBuilder &b, OpTy newXferOp,
261261
template <typename OpTy>
262262
static bool isTensorOp(OpTy xferOp) {
263263
if (isa<RankedTensorType>(xferOp.getShapedType())) {
264-
if (xferOp.getOperationName().equals(TransferWriteOp::getOperationName())) {
264+
if (xferOp.getOperationName() == TransferWriteOp::getOperationName()) {
265265
// TransferWriteOps on tensors have a result.
266266
assert(xferOp->getNumResults() > 0);
267267
}

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

+6-8
Original file line numberDiff line numberDiff line change
@@ -3585,20 +3585,18 @@ ParseResult AffinePrefetchOp::parse(OpAsmParser &parser,
35853585
parser.resolveOperands(mapOperands, indexTy, result.operands))
35863586
return failure();
35873587

3588-
if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
3588+
if (readOrWrite != "read" && readOrWrite != "write")
35893589
return parser.emitError(parser.getNameLoc(),
35903590
"rw specifier has to be 'read' or 'write'");
3591-
result.addAttribute(
3592-
AffinePrefetchOp::getIsWriteAttrStrName(),
3593-
parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));
3591+
result.addAttribute(AffinePrefetchOp::getIsWriteAttrStrName(),
3592+
parser.getBuilder().getBoolAttr(readOrWrite == "write"));
35943593

3595-
if (!cacheType.equals("data") && !cacheType.equals("instr"))
3594+
if (cacheType != "data" && cacheType != "instr")
35963595
return parser.emitError(parser.getNameLoc(),
35973596
"cache type has to be 'data' or 'instr'");
35983597

3599-
result.addAttribute(
3600-
AffinePrefetchOp::getIsDataCacheAttrStrName(),
3601-
parser.getBuilder().getBoolAttr(cacheType.equals("data")));
3598+
result.addAttribute(AffinePrefetchOp::getIsDataCacheAttrStrName(),
3599+
parser.getBuilder().getBoolAttr(cacheType == "data"));
36023600

36033601
return success();
36043602
}

mlir/lib/Dialect/GPU/IR/GPUDialect.cpp

+5-8
Original file line numberDiff line numberDiff line change
@@ -152,8 +152,7 @@ LogicalResult
152152
MMAMatrixType::verify(function_ref<InFlightDiagnostic()> emitError,
153153
ArrayRef<int64_t> shape, Type elementType,
154154
StringRef operand) {
155-
if (!operand.equals("AOp") && !operand.equals("BOp") &&
156-
!operand.equals("COp"))
155+
if (operand != "AOp" && operand != "BOp" && operand != "COp")
157156
return emitError() << "operand expected to be one of AOp, BOp or COp";
158157

159158
if (shape.size() != 2)
@@ -1941,8 +1940,7 @@ LogicalResult SubgroupMmaLoadMatrixOp::verify() {
19411940
return emitError(
19421941
"expected source memref most minor dim must have unit stride");
19431942

1944-
if (!operand.equals("AOp") && !operand.equals("BOp") &&
1945-
!operand.equals("COp"))
1943+
if (operand != "AOp" && operand != "BOp" && operand != "COp")
19461944
return emitError("only AOp, BOp and COp can be loaded");
19471945

19481946
return success();
@@ -1962,7 +1960,7 @@ LogicalResult SubgroupMmaStoreMatrixOp::verify() {
19621960
return emitError(
19631961
"expected destination memref most minor dim must have unit stride");
19641962

1965-
if (!srcMatrixType.getOperand().equals("COp"))
1963+
if (srcMatrixType.getOperand() != "COp")
19661964
return emitError(
19671965
"expected the operand matrix being stored to have 'COp' operand type");
19681966

@@ -1980,9 +1978,8 @@ LogicalResult SubgroupMmaComputeOp::verify() {
19801978
opTypes.push_back(llvm::cast<MMAMatrixType>(getOpB().getType()));
19811979
opTypes.push_back(llvm::cast<MMAMatrixType>(getOpC().getType()));
19821980

1983-
if (!opTypes[A].getOperand().equals("AOp") ||
1984-
!opTypes[B].getOperand().equals("BOp") ||
1985-
!opTypes[C].getOperand().equals("COp"))
1981+
if (opTypes[A].getOperand() != "AOp" || opTypes[B].getOperand() != "BOp" ||
1982+
opTypes[C].getOperand() != "COp")
19861983
return emitError("operands must be in the order AOp, BOp, COp");
19871984

19881985
ArrayRef<int64_t> aShape, bShape, cShape;

mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h

+1-2
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,7 @@ struct LLVMStructTypeStorage : public TypeStorage {
131131
/// Compares two keys.
132132
bool operator==(const Key &other) const {
133133
if (isIdentified())
134-
return other.isIdentified() &&
135-
other.getIdentifier().equals(getIdentifier());
134+
return other.isIdentified() && other.getIdentifier() == getIdentifier();
136135

137136
return !other.isIdentified() && other.isPacked() == isPacked() &&
138137
other.getTypeList() == getTypeList();

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

+6-8
Original file line numberDiff line numberDiff line change
@@ -1742,20 +1742,18 @@ ParseResult PrefetchOp::parse(OpAsmParser &parser, OperationState &result) {
17421742
parser.resolveOperands(indexInfo, indexTy, result.operands))
17431743
return failure();
17441744

1745-
if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
1745+
if (readOrWrite != "read" && readOrWrite != "write")
17461746
return parser.emitError(parser.getNameLoc(),
17471747
"rw specifier has to be 'read' or 'write'");
1748-
result.addAttribute(
1749-
PrefetchOp::getIsWriteAttrStrName(),
1750-
parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));
1748+
result.addAttribute(PrefetchOp::getIsWriteAttrStrName(),
1749+
parser.getBuilder().getBoolAttr(readOrWrite == "write"));
17511750

1752-
if (!cacheType.equals("data") && !cacheType.equals("instr"))
1751+
if (cacheType != "data" && cacheType != "instr")
17531752
return parser.emitError(parser.getNameLoc(),
17541753
"cache type has to be 'data' or 'instr'");
17551754

1756-
result.addAttribute(
1757-
PrefetchOp::getIsDataCacheAttrStrName(),
1758-
parser.getBuilder().getBoolAttr(cacheType.equals("data")));
1755+
result.addAttribute(PrefetchOp::getIsDataCacheAttrStrName(),
1756+
parser.getBuilder().getBoolAttr(cacheType == "data"));
17591757

17601758
return success();
17611759
}

mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -89,11 +89,11 @@ ParseResult LvlTypeParser::parseProperty(AsmParser &parser,
8989
auto loc = parser.getCurrentLocation();
9090
ERROR_IF(failed(parser.parseOptionalKeyword(&strVal)),
9191
"expected valid level property (e.g. nonordered, nonunique or high)")
92-
if (strVal.equals(toPropString(LevelPropNonDefault::Nonunique))) {
92+
if (strVal == toPropString(LevelPropNonDefault::Nonunique)) {
9393
*properties |= static_cast<uint64_t>(LevelPropNonDefault::Nonunique);
94-
} else if (strVal.equals(toPropString(LevelPropNonDefault::Nonordered))) {
94+
} else if (strVal == toPropString(LevelPropNonDefault::Nonordered)) {
9595
*properties |= static_cast<uint64_t>(LevelPropNonDefault::Nonordered);
96-
} else if (strVal.equals(toPropString(LevelPropNonDefault::SoA))) {
96+
} else if (strVal == toPropString(LevelPropNonDefault::SoA)) {
9797
*properties |= static_cast<uint64_t>(LevelPropNonDefault::SoA);
9898
} else {
9999
parser.emitError(loc, "unknown level property: ") << strVal;

mlir/lib/IR/AttributeDetail.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ struct DenseStringElementsAttrStorage : public DenseElementsAttributeStorage {
261261
// Check to see if this storage represents a splat. If it doesn't then
262262
// combine the hash for the data starting with the first non splat element.
263263
for (size_t i = 1, e = data.size(); i != e; i++)
264-
if (!firstElt.equals(data[i]))
264+
if (firstElt != data[i])
265265
return KeyTy(ty, data, llvm::hash_combine(hashVal, data.drop_front(i)));
266266

267267
// Otherwise, this is a splat so just return the hash of the first element.

mlir/lib/TableGen/Builder.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ Builder::Builder(const llvm::Record *record, ArrayRef<SMLoc> loc)
5252
// Initialize the parameters of the builder.
5353
const llvm::DagInit *dag = def->getValueAsDag("dagParams");
5454
auto *defInit = dyn_cast<llvm::DefInit>(dag->getOperator());
55-
if (!defInit || !defInit->getDef()->getName().equals("ins"))
55+
if (!defInit || defInit->getDef()->getName() != "ins")
5656
PrintFatalError(def->getLoc(), "expected 'ins' in builders");
5757

5858
bool seenDefaultValue = false;

mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ static LogicalResult setProfilingAttr(OpBuilder &builder, llvm::MDNode *node,
9393
return failure();
9494

9595
// Handle function entry count metadata.
96-
if (name->getString().equals("function_entry_count")) {
96+
if (name->getString() == "function_entry_count") {
9797

9898
// TODO support function entry count metadata with GUID fields.
9999
if (node->getNumOperands() != 2)
@@ -111,7 +111,7 @@ static LogicalResult setProfilingAttr(OpBuilder &builder, llvm::MDNode *node,
111111
<< "expected function_entry_count to be attached to a function";
112112
}
113113

114-
if (!name->getString().equals("branch_weights"))
114+
if (name->getString() != "branch_weights")
115115
return failure();
116116

117117
// Handle branch weights metadata.

mlir/test/lib/Dialect/Test/TestOpsSyntax.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -413,7 +413,7 @@ void PrettyPrintedRegionOp::print(OpAsmPrinter &p) {
413413
// of inner-op), then we can print the entire region in a succinct way.
414414
// Here we assume that the prototype of "test.special.op" can be trivially
415415
// derived while parsing it back.
416-
if (innerOp.getName().getStringRef().equals("test.special.op")) {
416+
if (innerOp.getName().getStringRef() == "test.special.op") {
417417
p << " start test.special.op end";
418418
} else {
419419
p << " (";

mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ static void collectAllDefs(StringRef selectedDialect,
5050
} else {
5151
// Otherwise, generate the defs that belong to the selected dialect.
5252
auto dialectDefs = llvm::make_filter_range(defs, [&](const auto &def) {
53-
return def.getDialect().getName().equals(selectedDialect);
53+
return def.getDialect().getName() == selectedDialect;
5454
});
5555
resultDefs.assign(dialectDefs.begin(), dialectDefs.end());
5656
}

mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -457,7 +457,7 @@ static void emitAttributeAccessors(const Operator &op, raw_ostream &os) {
457457
std::string sanitizedName = sanitizeName(namedAttr.name);
458458

459459
// Unit attributes are handled specially.
460-
if (namedAttr.attr.getStorageType().trim().equals("::mlir::UnitAttr")) {
460+
if (namedAttr.attr.getStorageType().trim() == "::mlir::UnitAttr") {
461461
os << llvm::formatv(unitAttributeGetterTemplate, sanitizedName,
462462
namedAttr.name);
463463
os << llvm::formatv(unitAttributeSetterTemplate, sanitizedName,
@@ -668,7 +668,7 @@ populateBuilderLinesAttr(const Operator &op,
668668
continue;
669669

670670
// Unit attributes are handled specially.
671-
if (attribute->attr.getStorageType().trim().equals("::mlir::UnitAttr")) {
671+
if (attribute->attr.getStorageType().trim() == "::mlir::UnitAttr") {
672672
builderLines.push_back(llvm::formatv(initUnitAttributeTemplate,
673673
attribute->name, argNames[i]));
674674
continue;

0 commit comments

Comments
 (0)