Skip to content

Commit d9acca4

Browse files
authored
[OpenMP] Fix atomic compare handling with overloaded operators (llvm#141142) (llvm#2480)
2 parents 5b0c3d9 + d23d92d commit d9acca4

File tree

2 files changed

+128
-65
lines changed

2 files changed

+128
-65
lines changed

clang/lib/Sema/SemaOpenMP.cpp

Lines changed: 97 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -12159,51 +12159,61 @@ bool OpenMPAtomicCompareChecker::checkCondUpdateStmt(IfStmt *S,
1215912159
X = BO->getLHS();
1216012160

1216112161
auto *Cond = dyn_cast<BinaryOperator>(S->getCond());
12162-
if (!Cond) {
12162+
auto *Call = dyn_cast<CXXOperatorCallExpr>(S->getCond());
12163+
Expr *LHS = nullptr;
12164+
Expr *RHS = nullptr;
12165+
if (Cond) {
12166+
LHS = Cond->getLHS();
12167+
RHS = Cond->getRHS();
12168+
} else if (Call) {
12169+
LHS = Call->getArg(0);
12170+
RHS = Call->getArg(1);
12171+
} else {
1216312172
ErrorInfo.Error = ErrorTy::NotABinaryOp;
1216412173
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = S->getCond()->getExprLoc();
1216512174
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = S->getCond()->getSourceRange();
1216612175
return false;
1216712176
}
1216812177

12169-
switch (Cond->getOpcode()) {
12170-
case BO_EQ: {
12171-
C = Cond;
12178+
if ((Cond && Cond->getOpcode() == BO_EQ) ||
12179+
(Call && Call->getOperator() == OverloadedOperatorKind::OO_EqualEqual)) {
12180+
C = S->getCond();
1217212181
D = BO->getRHS();
12173-
if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS())) {
12174-
E = Cond->getRHS();
12175-
} else if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
12176-
E = Cond->getLHS();
12182+
if (checkIfTwoExprsAreSame(ContextRef, X, LHS)) {
12183+
E = RHS;
12184+
} else if (checkIfTwoExprsAreSame(ContextRef, X, RHS)) {
12185+
E = LHS;
1217712186
} else {
1217812187
ErrorInfo.Error = ErrorTy::InvalidComparison;
12179-
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
12180-
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
12188+
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = S->getCond()->getExprLoc();
12189+
ErrorInfo.ErrorRange = ErrorInfo.NoteRange =
12190+
S->getCond()->getSourceRange();
1218112191
return false;
1218212192
}
12183-
break;
12184-
}
12185-
case BO_LT:
12186-
case BO_GT: {
12193+
} else if ((Cond &&
12194+
(Cond->getOpcode() == BO_LT || Cond->getOpcode() == BO_GT)) ||
12195+
(Call &&
12196+
(Call->getOperator() == OverloadedOperatorKind::OO_Less ||
12197+
Call->getOperator() == OverloadedOperatorKind::OO_Greater))) {
1218712198
E = BO->getRHS();
12188-
if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS()) &&
12189-
checkIfTwoExprsAreSame(ContextRef, E, Cond->getRHS())) {
12190-
C = Cond;
12191-
} else if (checkIfTwoExprsAreSame(ContextRef, E, Cond->getLHS()) &&
12192-
checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
12193-
C = Cond;
12199+
if (checkIfTwoExprsAreSame(ContextRef, X, LHS) &&
12200+
checkIfTwoExprsAreSame(ContextRef, E, RHS)) {
12201+
C = S->getCond();
12202+
} else if (checkIfTwoExprsAreSame(ContextRef, E, LHS) &&
12203+
checkIfTwoExprsAreSame(ContextRef, X, RHS)) {
12204+
C = S->getCond();
1219412205
IsXBinopExpr = false;
1219512206
} else {
1219612207
ErrorInfo.Error = ErrorTy::InvalidComparison;
12197-
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
12198-
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
12208+
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = S->getCond()->getExprLoc();
12209+
ErrorInfo.ErrorRange = ErrorInfo.NoteRange =
12210+
S->getCond()->getSourceRange();
1219912211
return false;
1220012212
}
12201-
break;
12202-
}
12203-
default:
12213+
} else {
1220412214
ErrorInfo.Error = ErrorTy::InvalidBinaryOp;
12205-
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
12206-
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
12215+
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = S->getCond()->getExprLoc();
12216+
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = S->getCond()->getSourceRange();
1220712217
return false;
1220812218
}
1220912219

@@ -12253,52 +12263,64 @@ bool OpenMPAtomicCompareChecker::checkCondExprStmt(Stmt *S,
1225312263
}
1225412264

1225512265
auto *Cond = dyn_cast<BinaryOperator>(CO->getCond());
12256-
if (!Cond) {
12266+
auto *Call = dyn_cast<CXXOperatorCallExpr>(CO->getCond());
12267+
Expr *LHS = nullptr;
12268+
Expr *RHS = nullptr;
12269+
if (Cond) {
12270+
LHS = Cond->getLHS();
12271+
RHS = Cond->getRHS();
12272+
} else if (Call) {
12273+
LHS = Call->getArg(0);
12274+
RHS = Call->getArg(1);
12275+
} else {
1225712276
ErrorInfo.Error = ErrorTy::NotABinaryOp;
1225812277
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = CO->getCond()->getExprLoc();
1225912278
ErrorInfo.ErrorRange = ErrorInfo.NoteRange =
1226012279
CO->getCond()->getSourceRange();
1226112280
return false;
1226212281
}
1226312282

12264-
switch (Cond->getOpcode()) {
12265-
case BO_EQ: {
12266-
C = Cond;
12283+
if ((Cond && Cond->getOpcode() == BO_EQ) ||
12284+
(Call && Call->getOperator() == OverloadedOperatorKind::OO_EqualEqual)) {
12285+
C = CO->getCond();
1226712286
D = CO->getTrueExpr();
12268-
if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS())) {
12269-
E = Cond->getRHS();
12270-
} else if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
12271-
E = Cond->getLHS();
12287+
if (checkIfTwoExprsAreSame(ContextRef, X, LHS)) {
12288+
E = RHS;
12289+
} else if (checkIfTwoExprsAreSame(ContextRef, X, RHS)) {
12290+
E = LHS;
1227212291
} else {
1227312292
ErrorInfo.Error = ErrorTy::InvalidComparison;
12274-
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
12275-
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
12293+
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = CO->getCond()->getExprLoc();
12294+
ErrorInfo.ErrorRange = ErrorInfo.NoteRange =
12295+
CO->getCond()->getSourceRange();
1227612296
return false;
1227712297
}
12278-
break;
12279-
}
12280-
case BO_LT:
12281-
case BO_GT: {
12298+
} else if ((Cond &&
12299+
(Cond->getOpcode() == BO_LT || Cond->getOpcode() == BO_GT)) ||
12300+
(Call &&
12301+
(Call->getOperator() == OverloadedOperatorKind::OO_Less ||
12302+
Call->getOperator() == OverloadedOperatorKind::OO_Greater))) {
12303+
1228212304
E = CO->getTrueExpr();
12283-
if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS()) &&
12284-
checkIfTwoExprsAreSame(ContextRef, E, Cond->getRHS())) {
12285-
C = Cond;
12286-
} else if (checkIfTwoExprsAreSame(ContextRef, E, Cond->getLHS()) &&
12287-
checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
12288-
C = Cond;
12305+
if (checkIfTwoExprsAreSame(ContextRef, X, LHS) &&
12306+
checkIfTwoExprsAreSame(ContextRef, E, RHS)) {
12307+
C = CO->getCond();
12308+
} else if (checkIfTwoExprsAreSame(ContextRef, E, LHS) &&
12309+
checkIfTwoExprsAreSame(ContextRef, X, RHS)) {
12310+
C = CO->getCond();
1228912311
IsXBinopExpr = false;
1229012312
} else {
1229112313
ErrorInfo.Error = ErrorTy::InvalidComparison;
12292-
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
12293-
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
12314+
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = CO->getCond()->getExprLoc();
12315+
ErrorInfo.ErrorRange = ErrorInfo.NoteRange =
12316+
CO->getCond()->getSourceRange();
1229412317
return false;
1229512318
}
12296-
break;
12297-
}
12298-
default:
12319+
} else {
1229912320
ErrorInfo.Error = ErrorTy::InvalidBinaryOp;
12300-
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
12301-
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
12321+
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = CO->getCond()->getExprLoc();
12322+
ErrorInfo.ErrorRange = ErrorInfo.NoteRange =
12323+
CO->getCond()->getSourceRange();
1230212324
return false;
1230312325
}
1230412326

@@ -12459,31 +12481,41 @@ bool OpenMPAtomicCompareCaptureChecker::checkForm3(IfStmt *S,
1245912481
D = BO->getRHS();
1246012482

1246112483
auto *Cond = dyn_cast<BinaryOperator>(S->getCond());
12462-
if (!Cond) {
12484+
auto *Call = dyn_cast<CXXOperatorCallExpr>(S->getCond());
12485+
Expr *LHS = nullptr;
12486+
Expr *RHS = nullptr;
12487+
if (Cond) {
12488+
LHS = Cond->getLHS();
12489+
RHS = Cond->getRHS();
12490+
} else if (Call) {
12491+
LHS = Call->getArg(0);
12492+
RHS = Call->getArg(1);
12493+
} else {
1246312494
ErrorInfo.Error = ErrorTy::NotABinaryOp;
1246412495
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = S->getCond()->getExprLoc();
1246512496
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = S->getCond()->getSourceRange();
1246612497
return false;
1246712498
}
12468-
if (Cond->getOpcode() != BO_EQ) {
12499+
if ((Cond && Cond->getOpcode() != BO_EQ) ||
12500+
(Call && Call->getOperator() != OverloadedOperatorKind::OO_EqualEqual)) {
1246912501
ErrorInfo.Error = ErrorTy::NotEQ;
12470-
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
12471-
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
12502+
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = S->getCond()->getExprLoc();
12503+
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = S->getCond()->getSourceRange();
1247212504
return false;
1247312505
}
1247412506

12475-
if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS())) {
12476-
E = Cond->getRHS();
12477-
} else if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
12478-
E = Cond->getLHS();
12507+
if (checkIfTwoExprsAreSame(ContextRef, X, LHS)) {
12508+
E = RHS;
12509+
} else if (checkIfTwoExprsAreSame(ContextRef, X, RHS)) {
12510+
E = LHS;
1247912511
} else {
1248012512
ErrorInfo.Error = ErrorTy::InvalidComparison;
12481-
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
12482-
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
12513+
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = S->getCond()->getExprLoc();
12514+
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = S->getCond()->getSourceRange();
1248312515
return false;
1248412516
}
1248512517

12486-
C = Cond;
12518+
C = S->getCond();
1248712519

1248812520
if (!S->getElse()) {
1248912521
ErrorInfo.Error = ErrorTy::NoElse;

clang/test/OpenMP/atomic_messages.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -991,3 +991,34 @@ int mixed() {
991991
// expected-note@+1 {{in instantiation of function template specialization 'mixed<int>' requested here}}
992992
return mixed<int>();
993993
}
994+
995+
#ifdef OMP51
996+
struct U {};
997+
struct U operator<(U, U);
998+
struct U operator>(U, U);
999+
struct U operator==(U, U);
1000+
1001+
template <typename T> void templated() {
1002+
T cx, cv, ce, cd;
1003+
#pragma omp atomic compare capture
1004+
if (cx == ce) {
1005+
cx = cd;
1006+
} else {
1007+
cv = cx;
1008+
}
1009+
#pragma omp atomic compare capture
1010+
{
1011+
cv = cx;
1012+
if (ce > cx) {
1013+
cx = ce;
1014+
}
1015+
}
1016+
#pragma omp atomic compare capture
1017+
{
1018+
cv = cx;
1019+
if (cx < ce) {
1020+
cx = ce;
1021+
}
1022+
}
1023+
}
1024+
#endif

0 commit comments

Comments
 (0)