-
Notifications
You must be signed in to change notification settings - Fork 13.9k
[OpenMP] Fix atomic compare handling with overloaded operators #141142
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-clang Author: Joseph Huber (jhuber6) ChangesSummary: Full diff: https://github.com/llvm/llvm-project/pull/141142.diff 2 Files Affected:
diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp
index f16f841d62edd..a0ad814c366d8 100644
--- a/clang/lib/Sema/SemaOpenMP.cpp
+++ b/clang/lib/Sema/SemaOpenMP.cpp
@@ -11762,52 +11762,98 @@ bool OpenMPAtomicCompareChecker::checkCondUpdateStmt(IfStmt *S,
X = BO->getLHS();
- auto *Cond = dyn_cast<BinaryOperator>(S->getCond());
- if (!Cond) {
- ErrorInfo.Error = ErrorTy::NotABinaryOp;
- ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = S->getCond()->getExprLoc();
- ErrorInfo.ErrorRange = ErrorInfo.NoteRange = S->getCond()->getSourceRange();
- return false;
- }
-
- switch (Cond->getOpcode()) {
- case BO_EQ: {
- C = Cond;
- D = BO->getRHS();
- if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS())) {
- E = Cond->getRHS();
- } else if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
- E = Cond->getLHS();
- } else {
- ErrorInfo.Error = ErrorTy::InvalidComparison;
+ if (auto *Cond = dyn_cast<BinaryOperator>(S->getCond())) {
+ switch (Cond->getOpcode()) {
+ case BO_EQ: {
+ C = Cond;
+ D = BO->getRHS();
+ if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS())) {
+ E = Cond->getRHS();
+ } else if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
+ E = Cond->getLHS();
+ } else {
+ ErrorInfo.Error = ErrorTy::InvalidComparison;
+ ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
+ ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
+ return false;
+ }
+ break;
+ }
+ case BO_LT:
+ case BO_GT: {
+ E = BO->getRHS();
+ if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS()) &&
+ checkIfTwoExprsAreSame(ContextRef, E, Cond->getRHS())) {
+ C = Cond;
+ } else if (checkIfTwoExprsAreSame(ContextRef, E, Cond->getLHS()) &&
+ checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
+ C = Cond;
+ IsXBinopExpr = false;
+ } else {
+ ErrorInfo.Error = ErrorTy::InvalidComparison;
+ ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
+ ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
+ return false;
+ }
+ break;
+ }
+ default:
+ ErrorInfo.Error = ErrorTy::InvalidBinaryOp;
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
return false;
}
- break;
- }
- case BO_LT:
- case BO_GT: {
- E = BO->getRHS();
- if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS()) &&
- checkIfTwoExprsAreSame(ContextRef, E, Cond->getRHS())) {
- C = Cond;
- } else if (checkIfTwoExprsAreSame(ContextRef, E, Cond->getLHS()) &&
- checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
- C = Cond;
- IsXBinopExpr = false;
- } else {
- ErrorInfo.Error = ErrorTy::InvalidComparison;
- ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
- ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
+ } else if (auto *Call = dyn_cast<CXXOperatorCallExpr>(S->getCond())) {
+ if (Call->getNumArgs() != 2) {
+ ErrorInfo.Error = ErrorTy::InvalidBinaryOp;
+ ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Call->getExprLoc();
+ ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Call->getSourceRange();
return false;
}
- break;
- }
- default:
- ErrorInfo.Error = ErrorTy::InvalidBinaryOp;
- ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
- ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
+ switch (Call->getOperator()) {
+ case clang::OverloadedOperatorKind::OO_EqualEqual: {
+ C = Call;
+ D = BO->getLHS();
+ if (checkIfTwoExprsAreSame(ContextRef, X, Call->getArg(0))) {
+ E = Call->getArg(1);
+ } else if (checkIfTwoExprsAreSame(ContextRef, X, Call->getArg(1))) {
+ E = Call->getArg(0);
+ } else {
+ ErrorInfo.Error = ErrorTy::InvalidComparison;
+ ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Call->getExprLoc();
+ ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Call->getSourceRange();
+ return false;
+ }
+ break;
+ }
+ case clang::OverloadedOperatorKind::OO_Greater:
+ case clang::OverloadedOperatorKind::OO_Less: {
+ E = BO->getRHS();
+ if (checkIfTwoExprsAreSame(ContextRef, X, Call->getArg(0)) &&
+ checkIfTwoExprsAreSame(ContextRef, E, Call->getArg(1))) {
+ C = Call;
+ } else if (checkIfTwoExprsAreSame(ContextRef, E, Call->getArg(0)) &&
+ checkIfTwoExprsAreSame(ContextRef, X, Call->getArg(1))) {
+ C = Call;
+ IsXBinopExpr = false;
+ } else {
+ ErrorInfo.Error = ErrorTy::InvalidComparison;
+ ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Call->getExprLoc();
+ ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Call->getSourceRange();
+ return false;
+ }
+ break;
+ }
+ default:
+ ErrorInfo.Error = ErrorTy::InvalidBinaryOp;
+ ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Call->getExprLoc();
+ ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Call->getSourceRange();
+ return false;
+ }
+ } else {
+ ErrorInfo.Error = ErrorTy::NotABinaryOp;
+ ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = S->getCond()->getExprLoc();
+ ErrorInfo.ErrorRange = ErrorInfo.NoteRange = S->getCond()->getSourceRange();
return false;
}
@@ -11856,53 +11902,99 @@ bool OpenMPAtomicCompareChecker::checkCondExprStmt(Stmt *S,
return false;
}
- auto *Cond = dyn_cast<BinaryOperator>(CO->getCond());
- if (!Cond) {
- ErrorInfo.Error = ErrorTy::NotABinaryOp;
- ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = CO->getCond()->getExprLoc();
- ErrorInfo.ErrorRange = ErrorInfo.NoteRange =
- CO->getCond()->getSourceRange();
- return false;
- }
-
- switch (Cond->getOpcode()) {
- case BO_EQ: {
- C = Cond;
- D = CO->getTrueExpr();
- if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS())) {
- E = Cond->getRHS();
- } else if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
- E = Cond->getLHS();
- } else {
- ErrorInfo.Error = ErrorTy::InvalidComparison;
+ if (auto *Cond = dyn_cast<BinaryOperator>(CO->getCond())) {
+ switch (Cond->getOpcode()) {
+ case BO_EQ: {
+ C = Cond;
+ D = CO->getTrueExpr();
+ if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS())) {
+ E = Cond->getRHS();
+ } else if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
+ E = Cond->getLHS();
+ } else {
+ ErrorInfo.Error = ErrorTy::InvalidComparison;
+ ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
+ ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
+ return false;
+ }
+ break;
+ }
+ case BO_LT:
+ case BO_GT: {
+ E = CO->getTrueExpr();
+ if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS()) &&
+ checkIfTwoExprsAreSame(ContextRef, E, Cond->getRHS())) {
+ C = Cond;
+ } else if (checkIfTwoExprsAreSame(ContextRef, E, Cond->getLHS()) &&
+ checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
+ C = Cond;
+ IsXBinopExpr = false;
+ } else {
+ ErrorInfo.Error = ErrorTy::InvalidComparison;
+ ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
+ ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
+ return false;
+ }
+ break;
+ }
+ default:
+ ErrorInfo.Error = ErrorTy::InvalidBinaryOp;
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
return false;
}
- break;
- }
- case BO_LT:
- case BO_GT: {
- E = CO->getTrueExpr();
- if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS()) &&
- checkIfTwoExprsAreSame(ContextRef, E, Cond->getRHS())) {
- C = Cond;
- } else if (checkIfTwoExprsAreSame(ContextRef, E, Cond->getLHS()) &&
- checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
- C = Cond;
- IsXBinopExpr = false;
- } else {
- ErrorInfo.Error = ErrorTy::InvalidComparison;
- ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
- ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
+ } else if (auto *Call = dyn_cast<CXXOperatorCallExpr>(CO->getCond())) {
+ if (Call->getNumArgs() != 2) {
+ ErrorInfo.Error = ErrorTy::InvalidBinaryOp;
+ ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Call->getExprLoc();
+ ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Call->getSourceRange();
return false;
}
- break;
- }
- default:
- ErrorInfo.Error = ErrorTy::InvalidBinaryOp;
- ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
- ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
+ switch (Call->getOperator()) {
+ case clang::OverloadedOperatorKind::OO_EqualEqual: {
+ C = Call;
+ D = CO->getTrueExpr();
+ if (checkIfTwoExprsAreSame(ContextRef, X, Call->getArg(0))) {
+ E = Call->getArg(1);
+ } else if (checkIfTwoExprsAreSame(ContextRef, X, Call->getArg(1))) {
+ E = Call->getArg(0);
+ } else {
+ ErrorInfo.Error = ErrorTy::InvalidComparison;
+ ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Call->getExprLoc();
+ ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Call->getSourceRange();
+ return false;
+ }
+ break;
+ }
+ case clang::OverloadedOperatorKind::OO_Less:
+ case clang::OverloadedOperatorKind::OO_Greater: {
+ E = CO->getTrueExpr();
+ if (checkIfTwoExprsAreSame(ContextRef, X, Call->getArg(0)) &&
+ checkIfTwoExprsAreSame(ContextRef, E, Call->getArg(1))) {
+ C = Call;
+ } else if (checkIfTwoExprsAreSame(ContextRef, E, Call->getArg(0)) &&
+ checkIfTwoExprsAreSame(ContextRef, X, Call->getArg(1))) {
+ C = Call;
+ IsXBinopExpr = false;
+ } else {
+ ErrorInfo.Error = ErrorTy::InvalidComparison;
+ ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Call->getExprLoc();
+ ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Call->getSourceRange();
+ return false;
+ }
+ break;
+ }
+ default:
+ ErrorInfo.Error = ErrorTy::InvalidBinaryOp;
+ ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Call->getExprLoc();
+ ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Call->getSourceRange();
+ return false;
+ }
+ } else {
+ ErrorInfo.Error = ErrorTy::NotABinaryOp;
+ ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = CO->getCond()->getExprLoc();
+ ErrorInfo.ErrorRange = ErrorInfo.NoteRange =
+ CO->getCond()->getSourceRange();
return false;
}
@@ -12062,32 +12154,56 @@ bool OpenMPAtomicCompareCaptureChecker::checkForm3(IfStmt *S,
X = BO->getLHS();
D = BO->getRHS();
- auto *Cond = dyn_cast<BinaryOperator>(S->getCond());
- if (!Cond) {
+ if (auto *Cond = dyn_cast<BinaryOperator>(S->getCond())) {
+ C = Cond;
+ if (Cond->getOpcode() != BO_EQ) {
+ ErrorInfo.Error = ErrorTy::NotEQ;
+ ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
+ ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
+ return false;
+ }
+
+ if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS())) {
+ E = Cond->getRHS();
+ } else if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
+ E = Cond->getLHS();
+ } else {
+ ErrorInfo.Error = ErrorTy::InvalidComparison;
+ ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
+ ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
+ return false;
+ }
+ } else if (auto *Call = dyn_cast<CXXOperatorCallExpr>(S->getCond())) {
+ C = Call;
+ if (Call->getNumArgs() != 2) {
+ ErrorInfo.Error = ErrorTy::InvalidBinaryOp;
+ ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Call->getExprLoc();
+ ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Call->getSourceRange();
+ return false;
+ }
+ if (Call->getOperator() != clang::OverloadedOperatorKind::OO_EqualEqual) {
+ ErrorInfo.Error = ErrorTy::NotEQ;
+ ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Call->getExprLoc();
+ ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Call->getSourceRange();
+ return false;
+ }
+
+ if (checkIfTwoExprsAreSame(ContextRef, X, Call->getArg(0))) {
+ E = Call->getArg(1);
+ } else if (checkIfTwoExprsAreSame(ContextRef, X, Call->getArg(1))) {
+ E = Call->getArg(0);
+ } else {
+ ErrorInfo.Error = ErrorTy::InvalidComparison;
+ ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Call->getExprLoc();
+ ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Call->getSourceRange();
+ return false;
+ }
+ } else {
ErrorInfo.Error = ErrorTy::NotABinaryOp;
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = S->getCond()->getExprLoc();
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = S->getCond()->getSourceRange();
return false;
}
- if (Cond->getOpcode() != BO_EQ) {
- ErrorInfo.Error = ErrorTy::NotEQ;
- ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
- ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
- return false;
- }
-
- if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS())) {
- E = Cond->getRHS();
- } else if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
- E = Cond->getLHS();
- } else {
- ErrorInfo.Error = ErrorTy::InvalidComparison;
- ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
- ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
- return false;
- }
-
- C = Cond;
if (!S->getElse()) {
ErrorInfo.Error = ErrorTy::NoElse;
diff --git a/clang/test/OpenMP/atomic_messages.cpp b/clang/test/OpenMP/atomic_messages.cpp
index d492f6ee1e896..c4e240a0ebb4e 100644
--- a/clang/test/OpenMP/atomic_messages.cpp
+++ b/clang/test/OpenMP/atomic_messages.cpp
@@ -991,3 +991,34 @@ int mixed() {
// expected-note@+1 {{in instantiation of function template specialization 'mixed<int>' requested here}}
return mixed<int>();
}
+
+#ifdef OMP51
+struct U {};
+struct U operator<(U, U);
+struct U operator>(U, U);
+struct U operator==(U, U);
+
+template <typename T> void templated() {
+ T cx, cv, ce, cd;
+#pragma omp atomic compare capture
+ if (cx == ce) {
+ cx = cd;
+ } else {
+ cv = cx;
+ }
+#pragma omp atomic compare capture
+ {
+ cv = cx;
+ if (ce > cx) {
+ cx = ce;
+ }
+ }
+#pragma omp atomic compare capture
+ {
+ cv = cx;
+ if (cx < ce) {
+ cx = ce;
+ }
+ }
+}
+#endif
|
clang/lib/Sema/SemaOpenMP.cpp
Outdated
ErrorInfo.Error = ErrorTy::InvalidComparison; | ||
ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc(); | ||
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange(); | ||
} else if (auto *Call = dyn_cast<CXXOperatorCallExpr>(S->getCond())) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think we should support this. We can't reliably lower non-trivial type at the moment.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We don't, this will get rejected later when the type is actually instantiated. This just prevents it from rejecting the format outright when it isn't even used, like in the associated test.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
but you can just check whether they are dependent type and call it a day?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Then these fields will be null and we'll trigger an assertion and we won't get any semantic checks.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay, yeah, I'll need some input from @alexey-bataev since I'm not a front end expert.
63ad350
to
a45dc43
Compare
} | ||
} | ||
} | ||
#endif |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We need a similar test as a codegen test as well.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Codegen wasn't the issue here, if you manually instantiated it the code would work fine, the issue was just Sema incorrectly rejecting this when there was no type given.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hm, fair, I guess.
a45dc43
to
f2c18ba
Compare
✅ With the latest revision this PR passed the C/C++ code formatter. |
Summary: When there are overloaded C++ operators in the global namespace the AST node for these is not a `BinaryExpr` but a `CXXOperatorCallExpr`. Modify the uses to handle this case, basically just treating it as a binary expression with two arguments.
f2c18ba
to
07caec3
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LG
Summary: When there are overloaded C++ operators in the global namespace the AST node for these is not a `BinaryExpr` but a `CXXOperatorCallExpr`. Modify the uses to handle this case, basically just treating it as a binary expression with two arguments. Fixes #141085
…141142) Summary: When there are overloaded C++ operators in the global namespace the AST node for these is not a `BinaryExpr` but a `CXXOperatorCallExpr`. Modify the uses to handle this case, basically just treating it as a binary expression with two arguments. Fixes llvm#141085
…141142) Summary: When there are overloaded C++ operators in the global namespace the AST node for these is not a `BinaryExpr` but a `CXXOperatorCallExpr`. Modify the uses to handle this case, basically just treating it as a binary expression with two arguments. Fixes llvm#141085
…141142) Summary: When there are overloaded C++ operators in the global namespace the AST node for these is not a `BinaryExpr` but a `CXXOperatorCallExpr`. Modify the uses to handle this case, basically just treating it as a binary expression with two arguments. Fixes llvm#141085
Summary:
When there are overloaded C++ operators in the global namespace the AST
node for these is not a
BinaryExpr
but aCXXOperatorCallExpr
. Modifythe uses to handle this case, basically just treating it as a binary
expression with two arguments.
Fixes #141085