Skip to content

Commit a45dc43

Browse files
committed
[OpenMP] Fix atomic compare handling with overloaded operators
Summary: When there are overloaded C++ operators in the global namespace the AST node for these is not a `BinaryExpr` but a `CXXOperatorCallExpr`. Modify the uses to handle this case, basically just treating it as a binary expression with two arguments.
1 parent 35ed9a3 commit a45dc43

File tree

2 files changed

+249
-102
lines changed

2 files changed

+249
-102
lines changed

clang/lib/Sema/SemaOpenMP.cpp

Lines changed: 218 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -11762,52 +11762,98 @@ bool OpenMPAtomicCompareChecker::checkCondUpdateStmt(IfStmt *S,
1176211762

1176311763
X = BO->getLHS();
1176411764

11765-
auto *Cond = dyn_cast<BinaryOperator>(S->getCond());
11766-
if (!Cond) {
11767-
ErrorInfo.Error = ErrorTy::NotABinaryOp;
11768-
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = S->getCond()->getExprLoc();
11769-
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = S->getCond()->getSourceRange();
11770-
return false;
11771-
}
11772-
11773-
switch (Cond->getOpcode()) {
11774-
case BO_EQ: {
11775-
C = Cond;
11776-
D = BO->getRHS();
11777-
if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS())) {
11778-
E = Cond->getRHS();
11779-
} else if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
11780-
E = Cond->getLHS();
11781-
} else {
11782-
ErrorInfo.Error = ErrorTy::InvalidComparison;
11765+
if (auto *Cond = dyn_cast<BinaryOperator>(S->getCond())) {
11766+
switch (Cond->getOpcode()) {
11767+
case BO_EQ: {
11768+
C = Cond;
11769+
D = BO->getRHS();
11770+
if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS())) {
11771+
E = Cond->getRHS();
11772+
} else if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
11773+
E = Cond->getLHS();
11774+
} else {
11775+
ErrorInfo.Error = ErrorTy::InvalidComparison;
11776+
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
11777+
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
11778+
return false;
11779+
}
11780+
break;
11781+
}
11782+
case BO_LT:
11783+
case BO_GT: {
11784+
E = BO->getRHS();
11785+
if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS()) &&
11786+
checkIfTwoExprsAreSame(ContextRef, E, Cond->getRHS())) {
11787+
C = Cond;
11788+
} else if (checkIfTwoExprsAreSame(ContextRef, E, Cond->getLHS()) &&
11789+
checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
11790+
C = Cond;
11791+
IsXBinopExpr = false;
11792+
} else {
11793+
ErrorInfo.Error = ErrorTy::InvalidComparison;
11794+
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
11795+
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
11796+
return false;
11797+
}
11798+
break;
11799+
}
11800+
default:
11801+
ErrorInfo.Error = ErrorTy::InvalidBinaryOp;
1178311802
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
1178411803
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
1178511804
return false;
1178611805
}
11787-
break;
11788-
}
11789-
case BO_LT:
11790-
case BO_GT: {
11791-
E = BO->getRHS();
11792-
if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS()) &&
11793-
checkIfTwoExprsAreSame(ContextRef, E, Cond->getRHS())) {
11794-
C = Cond;
11795-
} else if (checkIfTwoExprsAreSame(ContextRef, E, Cond->getLHS()) &&
11796-
checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
11797-
C = Cond;
11798-
IsXBinopExpr = false;
11799-
} else {
11800-
ErrorInfo.Error = ErrorTy::InvalidComparison;
11801-
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
11802-
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
11806+
} else if (auto *Call = dyn_cast<CXXOperatorCallExpr>(S->getCond())) {
11807+
if (Call->getNumArgs() != 2) {
11808+
ErrorInfo.Error = ErrorTy::InvalidBinaryOp;
11809+
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Call->getExprLoc();
11810+
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Call->getSourceRange();
1180311811
return false;
1180411812
}
11805-
break;
11806-
}
11807-
default:
11808-
ErrorInfo.Error = ErrorTy::InvalidBinaryOp;
11809-
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
11810-
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
11813+
switch (Call->getOperator()) {
11814+
case clang::OverloadedOperatorKind::OO_EqualEqual: {
11815+
C = Call;
11816+
D = BO->getLHS();
11817+
if (checkIfTwoExprsAreSame(ContextRef, X, Call->getArg(0))) {
11818+
E = Call->getArg(1);
11819+
} else if (checkIfTwoExprsAreSame(ContextRef, X, Call->getArg(1))) {
11820+
E = Call->getArg(0);
11821+
} else {
11822+
ErrorInfo.Error = ErrorTy::InvalidComparison;
11823+
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Call->getExprLoc();
11824+
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Call->getSourceRange();
11825+
return false;
11826+
}
11827+
break;
11828+
}
11829+
case clang::OverloadedOperatorKind::OO_Greater:
11830+
case clang::OverloadedOperatorKind::OO_Less: {
11831+
E = BO->getRHS();
11832+
if (checkIfTwoExprsAreSame(ContextRef, X, Call->getArg(0)) &&
11833+
checkIfTwoExprsAreSame(ContextRef, E, Call->getArg(1))) {
11834+
C = Call;
11835+
} else if (checkIfTwoExprsAreSame(ContextRef, E, Call->getArg(0)) &&
11836+
checkIfTwoExprsAreSame(ContextRef, X, Call->getArg(1))) {
11837+
C = Call;
11838+
IsXBinopExpr = false;
11839+
} else {
11840+
ErrorInfo.Error = ErrorTy::InvalidComparison;
11841+
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Call->getExprLoc();
11842+
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Call->getSourceRange();
11843+
return false;
11844+
}
11845+
break;
11846+
}
11847+
default:
11848+
ErrorInfo.Error = ErrorTy::InvalidBinaryOp;
11849+
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Call->getExprLoc();
11850+
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Call->getSourceRange();
11851+
return false;
11852+
}
11853+
} else {
11854+
ErrorInfo.Error = ErrorTy::NotABinaryOp;
11855+
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = S->getCond()->getExprLoc();
11856+
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = S->getCond()->getSourceRange();
1181111857
return false;
1181211858
}
1181311859

@@ -11856,53 +11902,99 @@ bool OpenMPAtomicCompareChecker::checkCondExprStmt(Stmt *S,
1185611902
return false;
1185711903
}
1185811904

11859-
auto *Cond = dyn_cast<BinaryOperator>(CO->getCond());
11860-
if (!Cond) {
11861-
ErrorInfo.Error = ErrorTy::NotABinaryOp;
11862-
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = CO->getCond()->getExprLoc();
11863-
ErrorInfo.ErrorRange = ErrorInfo.NoteRange =
11864-
CO->getCond()->getSourceRange();
11865-
return false;
11866-
}
11867-
11868-
switch (Cond->getOpcode()) {
11869-
case BO_EQ: {
11870-
C = Cond;
11871-
D = CO->getTrueExpr();
11872-
if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS())) {
11873-
E = Cond->getRHS();
11874-
} else if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
11875-
E = Cond->getLHS();
11876-
} else {
11877-
ErrorInfo.Error = ErrorTy::InvalidComparison;
11905+
if (auto *Cond = dyn_cast<BinaryOperator>(CO->getCond())) {
11906+
switch (Cond->getOpcode()) {
11907+
case BO_EQ: {
11908+
C = Cond;
11909+
D = CO->getTrueExpr();
11910+
if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS())) {
11911+
E = Cond->getRHS();
11912+
} else if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
11913+
E = Cond->getLHS();
11914+
} else {
11915+
ErrorInfo.Error = ErrorTy::InvalidComparison;
11916+
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
11917+
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
11918+
return false;
11919+
}
11920+
break;
11921+
}
11922+
case BO_LT:
11923+
case BO_GT: {
11924+
E = CO->getTrueExpr();
11925+
if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS()) &&
11926+
checkIfTwoExprsAreSame(ContextRef, E, Cond->getRHS())) {
11927+
C = Cond;
11928+
} else if (checkIfTwoExprsAreSame(ContextRef, E, Cond->getLHS()) &&
11929+
checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
11930+
C = Cond;
11931+
IsXBinopExpr = false;
11932+
} else {
11933+
ErrorInfo.Error = ErrorTy::InvalidComparison;
11934+
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
11935+
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
11936+
return false;
11937+
}
11938+
break;
11939+
}
11940+
default:
11941+
ErrorInfo.Error = ErrorTy::InvalidBinaryOp;
1187811942
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
1187911943
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
1188011944
return false;
1188111945
}
11882-
break;
11883-
}
11884-
case BO_LT:
11885-
case BO_GT: {
11886-
E = CO->getTrueExpr();
11887-
if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS()) &&
11888-
checkIfTwoExprsAreSame(ContextRef, E, Cond->getRHS())) {
11889-
C = Cond;
11890-
} else if (checkIfTwoExprsAreSame(ContextRef, E, Cond->getLHS()) &&
11891-
checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
11892-
C = Cond;
11893-
IsXBinopExpr = false;
11894-
} else {
11895-
ErrorInfo.Error = ErrorTy::InvalidComparison;
11896-
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
11897-
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
11946+
} else if (auto *Call = dyn_cast<CXXOperatorCallExpr>(CO->getCond())) {
11947+
if (Call->getNumArgs() != 2) {
11948+
ErrorInfo.Error = ErrorTy::InvalidBinaryOp;
11949+
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Call->getExprLoc();
11950+
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Call->getSourceRange();
1189811951
return false;
1189911952
}
11900-
break;
11901-
}
11902-
default:
11903-
ErrorInfo.Error = ErrorTy::InvalidBinaryOp;
11904-
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
11905-
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
11953+
switch (Call->getOperator()) {
11954+
case clang::OverloadedOperatorKind::OO_EqualEqual: {
11955+
C = Call;
11956+
D = CO->getTrueExpr();
11957+
if (checkIfTwoExprsAreSame(ContextRef, X, Call->getArg(0))) {
11958+
E = Call->getArg(1);
11959+
} else if (checkIfTwoExprsAreSame(ContextRef, X, Call->getArg(1))) {
11960+
E = Call->getArg(0);
11961+
} else {
11962+
ErrorInfo.Error = ErrorTy::InvalidComparison;
11963+
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Call->getExprLoc();
11964+
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Call->getSourceRange();
11965+
return false;
11966+
}
11967+
break;
11968+
}
11969+
case clang::OverloadedOperatorKind::OO_Less:
11970+
case clang::OverloadedOperatorKind::OO_Greater: {
11971+
E = CO->getTrueExpr();
11972+
if (checkIfTwoExprsAreSame(ContextRef, X, Call->getArg(0)) &&
11973+
checkIfTwoExprsAreSame(ContextRef, E, Call->getArg(1))) {
11974+
C = Call;
11975+
} else if (checkIfTwoExprsAreSame(ContextRef, E, Call->getArg(0)) &&
11976+
checkIfTwoExprsAreSame(ContextRef, X, Call->getArg(1))) {
11977+
C = Call;
11978+
IsXBinopExpr = false;
11979+
} else {
11980+
ErrorInfo.Error = ErrorTy::InvalidComparison;
11981+
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Call->getExprLoc();
11982+
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Call->getSourceRange();
11983+
return false;
11984+
}
11985+
break;
11986+
}
11987+
default:
11988+
ErrorInfo.Error = ErrorTy::InvalidBinaryOp;
11989+
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Call->getExprLoc();
11990+
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Call->getSourceRange();
11991+
return false;
11992+
}
11993+
} else {
11994+
ErrorInfo.Error = ErrorTy::NotABinaryOp;
11995+
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = CO->getCond()->getExprLoc();
11996+
ErrorInfo.ErrorRange = ErrorInfo.NoteRange =
11997+
CO->getCond()->getSourceRange();
1190611998
return false;
1190711999
}
1190812000

@@ -12062,32 +12154,56 @@ bool OpenMPAtomicCompareCaptureChecker::checkForm3(IfStmt *S,
1206212154
X = BO->getLHS();
1206312155
D = BO->getRHS();
1206412156

12065-
auto *Cond = dyn_cast<BinaryOperator>(S->getCond());
12066-
if (!Cond) {
12157+
if (auto *Cond = dyn_cast<BinaryOperator>(S->getCond())) {
12158+
C = Cond;
12159+
if (Cond->getOpcode() != BO_EQ) {
12160+
ErrorInfo.Error = ErrorTy::NotEQ;
12161+
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
12162+
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
12163+
return false;
12164+
}
12165+
12166+
if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS())) {
12167+
E = Cond->getRHS();
12168+
} else if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
12169+
E = Cond->getLHS();
12170+
} else {
12171+
ErrorInfo.Error = ErrorTy::InvalidComparison;
12172+
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
12173+
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
12174+
return false;
12175+
}
12176+
} else if (auto *Call = dyn_cast<CXXOperatorCallExpr>(S->getCond())) {
12177+
C = Call;
12178+
if (Call->getNumArgs() != 2) {
12179+
ErrorInfo.Error = ErrorTy::InvalidBinaryOp;
12180+
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Call->getExprLoc();
12181+
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Call->getSourceRange();
12182+
return false;
12183+
}
12184+
if (Call->getOperator() != clang::OverloadedOperatorKind::OO_EqualEqual) {
12185+
ErrorInfo.Error = ErrorTy::NotEQ;
12186+
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Call->getExprLoc();
12187+
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Call->getSourceRange();
12188+
return false;
12189+
}
12190+
12191+
if (checkIfTwoExprsAreSame(ContextRef, X, Call->getArg(0))) {
12192+
E = Call->getArg(1);
12193+
} else if (checkIfTwoExprsAreSame(ContextRef, X, Call->getArg(1))) {
12194+
E = Call->getArg(0);
12195+
} else {
12196+
ErrorInfo.Error = ErrorTy::InvalidComparison;
12197+
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Call->getExprLoc();
12198+
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Call->getSourceRange();
12199+
return false;
12200+
}
12201+
} else {
1206712202
ErrorInfo.Error = ErrorTy::NotABinaryOp;
1206812203
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = S->getCond()->getExprLoc();
1206912204
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = S->getCond()->getSourceRange();
1207012205
return false;
1207112206
}
12072-
if (Cond->getOpcode() != BO_EQ) {
12073-
ErrorInfo.Error = ErrorTy::NotEQ;
12074-
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
12075-
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
12076-
return false;
12077-
}
12078-
12079-
if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS())) {
12080-
E = Cond->getRHS();
12081-
} else if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
12082-
E = Cond->getLHS();
12083-
} else {
12084-
ErrorInfo.Error = ErrorTy::InvalidComparison;
12085-
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
12086-
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
12087-
return false;
12088-
}
12089-
12090-
C = Cond;
1209112207

1209212208
if (!S->getElse()) {
1209312209
ErrorInfo.Error = ErrorTy::NoElse;

clang/test/OpenMP/atomic_messages.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -991,3 +991,34 @@ int mixed() {
991991
// expected-note@+1 {{in instantiation of function template specialization 'mixed<int>' requested here}}
992992
return mixed<int>();
993993
}
994+
995+
#ifdef OMP51
996+
struct U {};
997+
struct U operator<(U, U);
998+
struct U operator>(U, U);
999+
struct U operator==(U, U);
1000+
1001+
template <typename T> void templated() {
1002+
T cx, cv, ce, cd;
1003+
#pragma omp atomic compare capture
1004+
if (cx == ce) {
1005+
cx = cd;
1006+
} else {
1007+
cv = cx;
1008+
}
1009+
#pragma omp atomic compare capture
1010+
{
1011+
cv = cx;
1012+
if (ce > cx) {
1013+
cx = ce;
1014+
}
1015+
}
1016+
#pragma omp atomic compare capture
1017+
{
1018+
cv = cx;
1019+
if (cx < ce) {
1020+
cx = ce;
1021+
}
1022+
}
1023+
}
1024+
#endif

0 commit comments

Comments
 (0)