Skip to content

Commit f2c18ba

Browse files
committed
[OpenMP] Fix atomic compare handling with overloaded operators
Summary: When there are overloaded C++ operators in the global namespace the AST node for these is not a `BinaryExpr` but a `CXXOperatorCallExpr`. Modify the uses to handle this case, basically just treating it as a binary expression with two arguments.
1 parent b4bc8c6 commit f2c18ba

File tree

2 files changed

+126
-67
lines changed

2 files changed

+126
-67
lines changed

clang/lib/Sema/SemaOpenMP.cpp

Lines changed: 95 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -11763,51 +11763,61 @@ bool OpenMPAtomicCompareChecker::checkCondUpdateStmt(IfStmt *S,
1176311763
X = BO->getLHS();
1176411764

1176511765
auto *Cond = dyn_cast<BinaryOperator>(S->getCond());
11766-
if (!Cond) {
11766+
auto *Call = dyn_cast<CXXOperatorCallExpr>(S->getCond());
11767+
Expr *LHS = nullptr;
11768+
Expr *RHS = nullptr;
11769+
if (Cond) {
11770+
LHS = Cond->getLHS();
11771+
RHS = Cond->getRHS();
11772+
} else if (Call) {
11773+
LHS = Call->getArg(0);
11774+
RHS = Call->getArg(1);
11775+
} else {
1176711776
ErrorInfo.Error = ErrorTy::NotABinaryOp;
1176811777
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = S->getCond()->getExprLoc();
1176911778
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = S->getCond()->getSourceRange();
1177011779
return false;
1177111780
}
1177211781

11773-
switch (Cond->getOpcode()) {
11774-
case BO_EQ: {
11775-
C = Cond;
11782+
if ((Cond && Cond->getOpcode() == BO_EQ) ||
11783+
(Call && Call->getOperator() == OverloadedOperatorKind::OO_EqualEqual)) {
11784+
C = S->getCond();
1177611785
D = BO->getRHS();
11777-
if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS())) {
11778-
E = Cond->getRHS();
11779-
} else if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
11780-
E = Cond->getLHS();
11786+
if (checkIfTwoExprsAreSame(ContextRef, X, LHS)) {
11787+
E = RHS;
11788+
} else if (checkIfTwoExprsAreSame(ContextRef, X, RHS)) {
11789+
E = LHS;
1178111790
} else {
1178211791
ErrorInfo.Error = ErrorTy::InvalidComparison;
11783-
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
11784-
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
11792+
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = S->getCond()->getExprLoc();
11793+
ErrorInfo.ErrorRange = ErrorInfo.NoteRange =
11794+
S->getCond()->getSourceRange();
1178511795
return false;
1178611796
}
11787-
break;
11788-
}
11789-
case BO_LT:
11790-
case BO_GT: {
11797+
} else if ((Cond &&
11798+
(Cond->getOpcode() == BO_LT || Cond->getOpcode() == BO_GT)) ||
11799+
(Call &&
11800+
(Call->getOperator() == OverloadedOperatorKind::OO_Less ||
11801+
Call->getOperator() == OverloadedOperatorKind::OO_Greater))) {
1179111802
E = BO->getRHS();
11792-
if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS()) &&
11793-
checkIfTwoExprsAreSame(ContextRef, E, Cond->getRHS())) {
11794-
C = Cond;
11795-
} else if (checkIfTwoExprsAreSame(ContextRef, E, Cond->getLHS()) &&
11796-
checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
11797-
C = Cond;
11803+
if (checkIfTwoExprsAreSame(ContextRef, X, LHS) &&
11804+
checkIfTwoExprsAreSame(ContextRef, E, RHS)) {
11805+
C = S->getCond();
11806+
} else if (checkIfTwoExprsAreSame(ContextRef, E, LHS) &&
11807+
checkIfTwoExprsAreSame(ContextRef, X, RHS)) {
11808+
C = S->getCond();
1179811809
IsXBinopExpr = false;
1179911810
} else {
1180011811
ErrorInfo.Error = ErrorTy::InvalidComparison;
11801-
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
11802-
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
11812+
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = S->getCond()->getExprLoc();
11813+
ErrorInfo.ErrorRange = ErrorInfo.NoteRange =
11814+
S->getCond()->getSourceRange();
1180311815
return false;
1180411816
}
11805-
break;
11806-
}
11807-
default:
11817+
} else {
1180811818
ErrorInfo.Error = ErrorTy::InvalidBinaryOp;
11809-
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
11810-
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
11819+
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = S->getCond()->getExprLoc();
11820+
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = S->getCond()->getSourceRange();
1181111821
return false;
1181211822
}
1181311823

@@ -11857,52 +11867,60 @@ bool OpenMPAtomicCompareChecker::checkCondExprStmt(Stmt *S,
1185711867
}
1185811868

1185911869
auto *Cond = dyn_cast<BinaryOperator>(CO->getCond());
11860-
if (!Cond) {
11870+
auto *Call = dyn_cast<CXXOperatorCallExpr>(CO->getCond());
11871+
Expr *LHS = nullptr;
11872+
Expr *RHS = nullptr;
11873+
if (Cond) {
11874+
LHS = Cond->getLHS();
11875+
RHS = Cond->getRHS();
11876+
} else if (Call) {
11877+
LHS = Call->getArg(0);
11878+
RHS = Call->getArg(1);
11879+
} else {
1186111880
ErrorInfo.Error = ErrorTy::NotABinaryOp;
1186211881
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = CO->getCond()->getExprLoc();
11863-
ErrorInfo.ErrorRange = ErrorInfo.NoteRange =
11864-
CO->getCond()->getSourceRange();
11882+
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = CO->getCond()->getSourceRange();
1186511883
return false;
1186611884
}
1186711885

11868-
switch (Cond->getOpcode()) {
11869-
case BO_EQ: {
11870-
C = Cond;
11886+
if ((Cond && Cond->getOpcode() == BO_EQ) ||
11887+
(Call && Call->getOperator() == OverloadedOperatorKind::OO_EqualEqual)) {
11888+
C = CO->getCond();
1187111889
D = CO->getTrueExpr();
11872-
if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS())) {
11873-
E = Cond->getRHS();
11874-
} else if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
11875-
E = Cond->getLHS();
11890+
if (checkIfTwoExprsAreSame(ContextRef, X, LHS)) {
11891+
E = RHS;
11892+
} else if (checkIfTwoExprsAreSame(ContextRef, X, RHS)) {
11893+
E = LHS;
1187611894
} else {
1187711895
ErrorInfo.Error = ErrorTy::InvalidComparison;
11878-
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
11879-
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
11896+
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = CO->getCond()->getExprLoc();
11897+
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = CO->getCond()->getSourceRange();
1188011898
return false;
1188111899
}
11882-
break;
11883-
}
11884-
case BO_LT:
11885-
case BO_GT: {
11900+
} else if ((Cond &&
11901+
(Cond->getOpcode() == BO_LT || Cond->getOpcode() == BO_GT)) ||
11902+
(Call &&
11903+
(Call->getOperator() == OverloadedOperatorKind::OO_Less ||
11904+
Call->getOperator() == OverloadedOperatorKind::OO_Greater))) {
11905+
1188611906
E = CO->getTrueExpr();
11887-
if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS()) &&
11888-
checkIfTwoExprsAreSame(ContextRef, E, Cond->getRHS())) {
11889-
C = Cond;
11890-
} else if (checkIfTwoExprsAreSame(ContextRef, E, Cond->getLHS()) &&
11891-
checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
11892-
C = Cond;
11907+
if (checkIfTwoExprsAreSame(ContextRef, X, LHS) &&
11908+
checkIfTwoExprsAreSame(ContextRef, E, RHS)) {
11909+
C = CO->getCond();
11910+
} else if (checkIfTwoExprsAreSame(ContextRef, E, LHS) &&
11911+
checkIfTwoExprsAreSame(ContextRef, X, RHS)) {
11912+
C = CO->getCond();
1189311913
IsXBinopExpr = false;
1189411914
} else {
1189511915
ErrorInfo.Error = ErrorTy::InvalidComparison;
11896-
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
11897-
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
11916+
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = CO->getCond()->getExprLoc();
11917+
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = CO->getCond()->getSourceRange();
1189811918
return false;
1189911919
}
11900-
break;
11901-
}
11902-
default:
11920+
} else {
1190311921
ErrorInfo.Error = ErrorTy::InvalidBinaryOp;
11904-
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
11905-
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
11922+
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = CO->getCond()->getExprLoc();
11923+
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = CO->getCond()->getSourceRange();
1190611924
return false;
1190711925
}
1190811926

@@ -12063,31 +12081,41 @@ bool OpenMPAtomicCompareCaptureChecker::checkForm3(IfStmt *S,
1206312081
D = BO->getRHS();
1206412082

1206512083
auto *Cond = dyn_cast<BinaryOperator>(S->getCond());
12066-
if (!Cond) {
12084+
auto *Call = dyn_cast<CXXOperatorCallExpr>(S->getCond());
12085+
Expr *LHS = nullptr;
12086+
Expr *RHS = nullptr;
12087+
if (Cond) {
12088+
LHS = Cond->getLHS();
12089+
RHS = Cond->getRHS();
12090+
} else if (Call) {
12091+
LHS = Call->getArg(0);
12092+
RHS = Call->getArg(1);
12093+
} else {
1206712094
ErrorInfo.Error = ErrorTy::NotABinaryOp;
1206812095
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = S->getCond()->getExprLoc();
1206912096
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = S->getCond()->getSourceRange();
1207012097
return false;
1207112098
}
12072-
if (Cond->getOpcode() != BO_EQ) {
12099+
if ((Cond && Cond->getOpcode() != BO_EQ) ||
12100+
(Call && Call->getOperator() != OverloadedOperatorKind::OO_EqualEqual)) {
1207312101
ErrorInfo.Error = ErrorTy::NotEQ;
12074-
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
12075-
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
12102+
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = S->getCond()->getExprLoc();
12103+
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = S->getCond()->getSourceRange();
1207612104
return false;
1207712105
}
1207812106

12079-
if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS())) {
12080-
E = Cond->getRHS();
12081-
} else if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
12082-
E = Cond->getLHS();
12107+
if (checkIfTwoExprsAreSame(ContextRef, X, LHS)) {
12108+
E = RHS;
12109+
} else if (checkIfTwoExprsAreSame(ContextRef, X, RHS)) {
12110+
E = LHS;
1208312111
} else {
1208412112
ErrorInfo.Error = ErrorTy::InvalidComparison;
12085-
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
12086-
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
12113+
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = S->getCond()->getExprLoc();
12114+
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = S->getCond()->getSourceRange();
1208712115
return false;
1208812116
}
1208912117

12090-
C = Cond;
12118+
C = S->getCond();
1209112119

1209212120
if (!S->getElse()) {
1209312121
ErrorInfo.Error = ErrorTy::NoElse;

clang/test/OpenMP/atomic_messages.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -991,3 +991,34 @@ int mixed() {
991991
// expected-note@+1 {{in instantiation of function template specialization 'mixed<int>' requested here}}
992992
return mixed<int>();
993993
}
994+
995+
#ifdef OMP51
996+
struct U {};
997+
struct U operator<(U, U);
998+
struct U operator>(U, U);
999+
struct U operator==(U, U);
1000+
1001+
template <typename T> void templated() {
1002+
T cx, cv, ce, cd;
1003+
#pragma omp atomic compare capture
1004+
if (cx == ce) {
1005+
cx = cd;
1006+
} else {
1007+
cv = cx;
1008+
}
1009+
#pragma omp atomic compare capture
1010+
{
1011+
cv = cx;
1012+
if (ce > cx) {
1013+
cx = ce;
1014+
}
1015+
}
1016+
#pragma omp atomic compare capture
1017+
{
1018+
cv = cx;
1019+
if (cx < ce) {
1020+
cx = ce;
1021+
}
1022+
}
1023+
}
1024+
#endif

0 commit comments

Comments
 (0)