-
Notifications
You must be signed in to change notification settings - Fork 14k
[flang][NFC] Move new code to right place #144551
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Some new code was added to flang/Semantics that only depends on facilities in flang/Evaluate. Move it into Evaluate and clean up some minor stylistic problems.
@llvm/pr-subscribers-openacc @llvm/pr-subscribers-flang-openmp Author: Peter Klausler (klausler) ChangesSome new code was added to flang/Semantics that only depends on facilities in flang/Evaluate. Move it into Evaluate and clean up some minor stylistic problems. Patch is 32.05 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/144551.diff 7 Files Affected:
diff --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h
index 1959d5f3a5899..e04621f71f9a7 100644
--- a/flang/include/flang/Evaluate/tools.h
+++ b/flang/include/flang/Evaluate/tools.h
@@ -1389,6 +1389,154 @@ inline bool HasCUDAImplicitTransfer(const Expr<SomeType> &expr) {
return (hasConstant || (hostSymbols > 0)) && deviceSymbols > 0;
}
+// Checks whether the symbol on the LHS is present in the RHS expression.
+bool CheckForSymbolMatch(const Expr<SomeType> *lhs, const Expr<SomeType> *rhs);
+
+namespace operation {
+
+enum class Operator {
+ Unknown,
+ Add,
+ And,
+ Associated,
+ Call,
+ Constant,
+ Convert,
+ Div,
+ Eq,
+ Eqv,
+ False,
+ Ge,
+ Gt,
+ Identity,
+ Intrinsic,
+ Le,
+ Lt,
+ Max,
+ Min,
+ Mul,
+ Ne,
+ Neqv,
+ Not,
+ Or,
+ Pow,
+ Resize, // Convert within the same TypeCategory
+ Sub,
+ True,
+};
+
+std::string ToString(Operator op);
+
+template <typename... Ts, int Kind>
+Operator OperationCode(
+ const evaluate::Operation<evaluate::LogicalOperation<Kind>, Ts...> &op) {
+ switch (op.derived().logicalOperator) {
+ case common::LogicalOperator::And:
+ return Operator::And;
+ case common::LogicalOperator::Or:
+ return Operator::Or;
+ case common::LogicalOperator::Eqv:
+ return Operator::Eqv;
+ case common::LogicalOperator::Neqv:
+ return Operator::Neqv;
+ case common::LogicalOperator::Not:
+ return Operator::Not;
+ }
+ return Operator::Unknown;
+}
+
+template <typename T, typename... Ts>
+Operator OperationCode(
+ const evaluate::Operation<evaluate::Relational<T>, Ts...> &op) {
+ switch (op.derived().opr) {
+ case common::RelationalOperator::LT:
+ return Operator::Lt;
+ case common::RelationalOperator::LE:
+ return Operator::Le;
+ case common::RelationalOperator::EQ:
+ return Operator::Eq;
+ case common::RelationalOperator::NE:
+ return Operator::Ne;
+ case common::RelationalOperator::GE:
+ return Operator::Ge;
+ case common::RelationalOperator::GT:
+ return Operator::Gt;
+ }
+ return Operator::Unknown;
+}
+
+template <typename T, typename... Ts>
+Operator OperationCode(const evaluate::Operation<evaluate::Add<T>, Ts...> &op) {
+ return Operator::Add;
+}
+
+template <typename T, typename... Ts>
+Operator OperationCode(
+ const evaluate::Operation<evaluate::Subtract<T>, Ts...> &op) {
+ return Operator::Sub;
+}
+
+template <typename T, typename... Ts>
+Operator OperationCode(
+ const evaluate::Operation<evaluate::Multiply<T>, Ts...> &op) {
+ return Operator::Mul;
+}
+
+template <typename T, typename... Ts>
+Operator OperationCode(
+ const evaluate::Operation<evaluate::Divide<T>, Ts...> &op) {
+ return Operator::Div;
+}
+
+template <typename T, typename... Ts>
+Operator OperationCode(
+ const evaluate::Operation<evaluate::Power<T>, Ts...> &op) {
+ return Operator::Pow;
+}
+
+template <typename T, typename... Ts>
+Operator OperationCode(
+ const evaluate::Operation<evaluate::RealToIntPower<T>, Ts...> &op) {
+ return Operator::Pow;
+}
+
+template <typename T, common::TypeCategory C, typename... Ts>
+Operator OperationCode(
+ const evaluate::Operation<evaluate::Convert<T, C>, Ts...> &op) {
+ if constexpr (C == T::category) {
+ return Operator::Resize;
+ } else {
+ return Operator::Convert;
+ }
+}
+
+template <typename T> Operator OperationCode(const evaluate::Constant<T> &x) {
+ return Operator::Constant;
+}
+
+template <typename T> Operator OperationCode(const T &) {
+ return Operator::Unknown;
+}
+
+Operator OperationCode(const evaluate::ProcedureDesignator &proc);
+
+} // namespace operation
+
+// Return information about the top-level operation (ignoring parentheses):
+// the operation code and the list of arguments.
+std::pair<operation::Operator, std::vector<Expr<SomeType>>>
+GetTopLevelOperation(const Expr<SomeType> &expr);
+
+// Check if expr is same as x, or a sequence of Convert operations on x.
+bool IsSameOrConvertOf(const Expr<SomeType> &expr, const Expr<SomeType> &x);
+
+// Strip away any top-level Convert operations (if any exist) and return
+// the input value. A ComplexConstructor(x, 0) is also considered as a
+// convert operation.
+// If the input is not Operation, Designator, FunctionRef or Constant,
+// it returns std::nullopt.
+std::optional<Expr<SomeType>> GetConvertInput(const Expr<SomeType> &x);
+
} // namespace Fortran::evaluate
namespace Fortran::semantics {
diff --git a/flang/include/flang/Semantics/tools.h b/flang/include/flang/Semantics/tools.h
index 69375a83dec25..f3cfa9b99fb4d 100644
--- a/flang/include/flang/Semantics/tools.h
+++ b/flang/include/flang/Semantics/tools.h
@@ -756,154 +756,5 @@ std::string GetCommonBlockObjectName(const Symbol &, bool underscoring);
// Check for ambiguous USE associations
bool HadUseError(SemanticsContext &, SourceName at, const Symbol *);
-// Checks whether the symbol on the LHS is present in the RHS expression.
-bool CheckForSymbolMatch(const SomeExpr *lhs, const SomeExpr *rhs);
-
-namespace operation {
-
-enum class Operator {
- Unknown,
- Add,
- And,
- Associated,
- Call,
- Constant,
- Convert,
- Div,
- Eq,
- Eqv,
- False,
- Ge,
- Gt,
- Identity,
- Intrinsic,
- Le,
- Lt,
- Max,
- Min,
- Mul,
- Ne,
- Neqv,
- Not,
- Or,
- Pow,
- Resize, // Convert within the same TypeCategory
- Sub,
- True,
-};
-
-std::string ToString(Operator op);
-
-template <typename... Ts, int Kind>
-Operator OperationCode(
- const evaluate::Operation<evaluate::LogicalOperation<Kind>, Ts...> &op) {
- switch (op.derived().logicalOperator) {
- case common::LogicalOperator::And:
- return Operator::And;
- case common::LogicalOperator::Or:
- return Operator::Or;
- case common::LogicalOperator::Eqv:
- return Operator::Eqv;
- case common::LogicalOperator::Neqv:
- return Operator::Neqv;
- case common::LogicalOperator::Not:
- return Operator::Not;
- }
- return Operator::Unknown;
-}
-
-template <typename T, typename... Ts>
-Operator OperationCode(
- const evaluate::Operation<evaluate::Relational<T>, Ts...> &op) {
- switch (op.derived().opr) {
- case common::RelationalOperator::LT:
- return Operator::Lt;
- case common::RelationalOperator::LE:
- return Operator::Le;
- case common::RelationalOperator::EQ:
- return Operator::Eq;
- case common::RelationalOperator::NE:
- return Operator::Ne;
- case common::RelationalOperator::GE:
- return Operator::Ge;
- case common::RelationalOperator::GT:
- return Operator::Gt;
- }
- return Operator::Unknown;
-}
-
-template <typename T, typename... Ts>
-Operator OperationCode(const evaluate::Operation<evaluate::Add<T>, Ts...> &op) {
- return Operator::Add;
-}
-
-template <typename T, typename... Ts>
-Operator OperationCode(
- const evaluate::Operation<evaluate::Subtract<T>, Ts...> &op) {
- return Operator::Sub;
-}
-
-template <typename T, typename... Ts>
-Operator OperationCode(
- const evaluate::Operation<evaluate::Multiply<T>, Ts...> &op) {
- return Operator::Mul;
-}
-
-template <typename T, typename... Ts>
-Operator OperationCode(
- const evaluate::Operation<evaluate::Divide<T>, Ts...> &op) {
- return Operator::Div;
-}
-
-template <typename T, typename... Ts>
-Operator OperationCode(
- const evaluate::Operation<evaluate::Power<T>, Ts...> &op) {
- return Operator::Pow;
-}
-
-template <typename T, typename... Ts>
-Operator OperationCode(
- const evaluate::Operation<evaluate::RealToIntPower<T>, Ts...> &op) {
- return Operator::Pow;
-}
-
-template <typename T, common::TypeCategory C, typename... Ts>
-Operator OperationCode(
- const evaluate::Operation<evaluate::Convert<T, C>, Ts...> &op) {
- if constexpr (C == T::category) {
- return Operator::Resize;
- } else {
- return Operator::Convert;
- }
-}
-
-template <typename T> //
-Operator OperationCode(const evaluate::Constant<T> &x) {
- return Operator::Constant;
-}
-
-template <typename T> //
-Operator OperationCode(const T &) {
- return Operator::Unknown;
-}
-
-Operator OperationCode(const evaluate::ProcedureDesignator &proc);
-
-} // namespace operation
-
-/// Return information about the top-level operation (ignoring parentheses):
-/// the operation code and the list of arguments.
-std::pair<operation::Operator, std::vector<SomeExpr>> GetTopLevelOperation(
- const SomeExpr &expr);
-
-/// Check if expr is same as x, or a sequence of Convert operations on x.
-bool IsSameOrConvertOf(const SomeExpr &expr, const SomeExpr &x);
-
-/// Strip away any top-level Convert operations (if any exist) and return
-/// the input value. A ComplexConstructor(x, 0) is also considered as a
-/// convert operation.
-/// If the input is not Operation, Designator, FunctionRef or Constant,
-/// it returns std::nullopt.
-MaybeExpr GetConvertInput(const SomeExpr &x);
} // namespace Fortran::semantics
#endif // FORTRAN_SEMANTICS_TOOLS_H_
diff --git a/flang/lib/Evaluate/tools.cpp b/flang/lib/Evaluate/tools.cpp
index 222c32a9c332e..68838564f87ba 100644
--- a/flang/lib/Evaluate/tools.cpp
+++ b/flang/lib/Evaluate/tools.cpp
@@ -13,6 +13,7 @@
#include "flang/Evaluate/traverse.h"
#include "flang/Parser/message.h"
#include "flang/Semantics/tools.h"
+#include "llvm/ADT/StringSwitch.h"
#include <algorithm>
#include <variant>
@@ -1595,6 +1596,316 @@ bool CheckForCoindexedObject(parser::ContextualMessages &messages,
}
}
+bool CheckForSymbolMatch(const Expr<SomeType> *lhs, const Expr<SomeType> *rhs) {
+ if (lhs && rhs) {
+ if (SymbolVector lhsSymbols{GetSymbolVector(*lhs)}; !lhsSymbols.empty()) {
+ const Symbol &first{*lhsSymbols.front()};
+ for (const Symbol &symbol : GetSymbolVector(*rhs)) {
+ if (first == symbol) {
+ return true;
+ }
+ }
+ }
+ }
+ return false;
+}
+
+namespace operation {
+template <typename T> Expr<SomeType> AsSomeExpr(const T &x) {
+ return AsGenericExpr(common::Clone(x));
+}
+
+template <bool IgnoreResizingConverts>
+struct ArgumentExtractor
+ : public Traverse<ArgumentExtractor<IgnoreResizingConverts>,
+ std::pair<operation::Operator, std::vector<Expr<SomeType>>>, false> {
+ using Arguments = std::vector<Expr<SomeType>>;
+ using Result = std::pair<operation::Operator, Arguments>;
+ using Base =
+ Traverse<ArgumentExtractor<IgnoreResizingConverts>, Result, false>;
+ static constexpr auto IgnoreResizes{IgnoreResizingConverts};
+ static constexpr auto Logical{common::TypeCategory::Logical};
+ ArgumentExtractor() : Base(*this) {}
+
+ Result Default() const { return {}; }
+
+ using Base::operator();
+
+ template <int Kind>
+ Result operator()(const Constant<Type<Logical, Kind>> &x) const {
+ if (const auto &val{x.GetScalarValue()}) {
+ return val->IsTrue()
+ ? std::make_pair(operation::Operator::True, Arguments{})
+ : std::make_pair(operation::Operator::False, Arguments{});
+ }
+ return Default();
+ }
+
+ template <typename R> Result operator()(const FunctionRef<R> &x) const {
+ Result result{operation::OperationCode(x.proc()), {}};
+ for (size_t i{0}, e{x.arguments().size()}; i != e; ++i) {
+ if (auto *e{x.UnwrapArgExpr(i)}) {
+ result.second.push_back(*e);
+ }
+ }
+ return result;
+ }
+
+ template <typename D, typename R, typename... Os>
+ Result operator()(const Operation<D, R, Os...> &x) const {
+ if constexpr (std::is_same_v<D, Parentheses<R>>) {
+ // Ignore top-level parentheses.
+ return (*this)(x.template operand<0>());
+ }
+ if constexpr (IgnoreResizes && std::is_same_v<D, Convert<R, R::category>>) {
+ // Ignore conversions within the same category.
+ // Atomic operations on int(kind=1) may be implicitly widened
+ // to int(kind=4) for example.
+ return (*this)(x.template operand<0>());
+ } else {
+ return std::make_pair(operation::OperationCode(x),
+ OperationArgs(x, std::index_sequence_for<Os...>{}));
+ }
+ }
+
+ template <typename T> Result operator()(const Designator<T> &x) const {
+ return {operation::Operator::Identity, {AsSomeExpr(x)}};
+ }
+
+ template <typename T> Result operator()(const Constant<T> &x) const {
+ return {operation::Operator::Identity, {AsSomeExpr(x)}};
+ }
+
+ template <typename... Rs>
+ Result Combine(Result &&result, Rs &&...results) const {
+ // There shouldn't be any combining needed, since we're stopping the
+ // traversal at the top-level operation, but implement one that picks
+ // the first non-empty result.
+ if constexpr (sizeof...(Rs) == 0) {
+ return std::move(result);
+ } else {
+ if (!result.second.empty()) {
+ return std::move(result);
+ } else {
+ return Combine(std::move(results)...);
+ }
+ }
+ }
+
+private:
+ template <typename D, typename R, typename... Os, size_t... Is>
+ Arguments OperationArgs(
+ const Operation<D, R, Os...> &x, std::index_sequence<Is...>) const {
+ return Arguments{Expr<SomeType>(x.template operand<Is>())...};
+ }
+};
+} // namespace operation
+
+std::string operation::ToString(operation::Operator op) {
+ switch (op) {
+ case Operator::Unknown:
+ return "??";
+ case Operator::Add:
+ return "+";
+ case Operator::And:
+ return "AND";
+ case Operator::Associated:
+ return "ASSOCIATED";
+ case Operator::Call:
+ return "function-call";
+ case Operator::Constant:
+ return "constant";
+ case Operator::Convert:
+ return "type-conversion";
+ case Operator::Div:
+ return "/";
+ case Operator::Eq:
+ return "==";
+ case Operator::Eqv:
+ return "EQV";
+ case Operator::False:
+ return ".FALSE.";
+ case Operator::Ge:
+ return ">=";
+ case Operator::Gt:
+ return ">";
+ case Operator::Identity:
+ return "identity";
+ case Operator::Intrinsic:
+ return "intrinsic";
+ case Operator::Le:
+ return "<=";
+ case Operator::Lt:
+ return "<";
+ case Operator::Max:
+ return "MAX";
+ case Operator::Min:
+ return "MIN";
+ case Operator::Mul:
+ return "*";
+ case Operator::Ne:
+ return "/=";
+ case Operator::Neqv:
+ return "NEQV/EOR";
+ case Operator::Not:
+ return "NOT";
+ case Operator::Or:
+ return "OR";
+ case Operator::Pow:
+ return "**";
+ case Operator::Resize:
+ return "resize";
+ case Operator::Sub:
+ return "-";
+ case Operator::True:
+ return ".TRUE.";
+ }
+ llvm_unreachable("Unhandler operator");
+}
+
+operation::Operator operation::OperationCode(const ProcedureDesignator &proc) {
+ Operator code{llvm::StringSwitch<Operator>(proc.GetName())
+ .Case("associated", Operator::Associated)
+ .Case("min", Operator::Min)
+ .Case("max", Operator::Max)
+ .Case("iand", Operator::And)
+ .Case("ior", Operator::Or)
+ .Case("ieor", Operator::Neqv)
+ .Default(Operator::Call)};
+ if (code == Operator::Call && proc.GetSpecificIntrinsic()) {
+ return Operator::Intrinsic;
+ }
+ return code;
+}
+
+std::pair<operation::Operator, std::vector<Expr<SomeType>>>
+GetTopLevelOperation(const Expr<SomeType> &expr) {
+ return operation::ArgumentExtractor<true>{}(expr);
+}
+
+namespace operation {
+struct ConvertCollector
+ : public Traverse<ConvertCollector,
+ std::pair<std::optional<Expr<SomeType>>, std::vector<DynamicType>>,
+ false> {
+ using Result =
+ std::pair<std::optional<Expr<SomeType>>, std::vector<DynamicType>>;
+ using Base = Traverse<ConvertCollector, Result, false>;
+ ConvertCollector() : Base(*this) {}
+
+ Result Default() const { return {}; }
+
+ using Base::operator();
+
+ template <typename T> Result operator()(const Designator<T> &x) const {
+ return {AsSomeExpr(x), {}};
+ }
+
+ template <typename T> Result operator()(const FunctionRef<T> &x) const {
+ return {AsSomeExpr(x), {}};
+ }
+
+ template <typename T> Result operator()(const Constant<T> &x) const {
+ return {AsSomeExpr(x), {}};
+ }
+
+ template <typename D, typename R, typename... Os>
+ Result operator()(const Operation<D, R, Os...> &x) const {
+ if constexpr (std::is_same_v<D, Parentheses<R>>) {
+ // Ignore parentheses.
+ return (*this)(x.template operand<0>());
+ } else if constexpr (is_convert_v<D>) {
+ // Convert should always have a typed result, so it should be safe to
+ // dereference x.GetType().
+ return Combine(
+ {std::nullopt, {*x.GetType()}}, (*this)(x.template operand<0>()));
+ } else if constexpr (is_complex_constructor_v<D>) {
+ // This is a conversion iff the imaginary operand is 0.
+ if (IsZero(x.template operand<1>())) {
+ return Combine(
+ {std::nullopt, {*x.GetType()}}, (*this)(x.template operand<0>()));
+ } else {
+ return {AsSomeExpr(x.derived()), {}};
+ }
+ } else {
+ return {AsSomeExpr(x.derived()), {}};
+ }
+ }
+
+ template <typename... Rs>
+ Result Combine(Result &&result, Rs &&...results) const {
+ Result v(std::move(result));
+ auto setValue{[](std::optional<Expr<SomeType>> &x,
+ std::optional<Expr<SomeType>> &&y) {
+ assert((!x.has_value() || !y.has_value()) && "Multiple designators");
+ if (!x.has_value()) {
+ x = std::move(y);
+ }
+ }};
+ auto moveAppend{[](auto &accum, auto &&other) {
+ for (auto &&s : other) {
+ accum.push_back(std::move(s));
+ }
+ }};
+ (setValue(v.first, std::move(results).first), ...);
+ (moveAppend(v.second, std::move(results).second), ...);
+ return v;
+ }
+
+private:
+ template <typename A> static bool IsZero(const A &x) { return false; }
+ template <typename T> static bool IsZero(const Expr<T> &x) {
+ return common::visit([](auto &&s) { return IsZero(s); }, x.u);
+ }
+ template <typename T> static bool IsZero(const Constant<T> &x) {
+ if (auto &&maybeScalar{x.GetScalarValue()}) {
+ return maybeScalar->IsZero();
+ } else {
+ return false;
+ }
+ }
+
+ template <typename T> struct is_convert {
+ static constexpr bool value{false};
+ };
+ template <typename T, common::TypeCategory C>
+ struct is_convert<Convert<T, C>> {
+ static constexpr bool value{true};
+ };
+ template <int K> struct is_convert<ComplexComponent<K>> {
+ // Conversion from complex to real.
+ static constexpr bool value{true};
+ };
+ template <typename T>
+ static constexpr bool is_convert_v{is_convert<T>::value};
+
+ template <typename T> struct is_complex_constructor {
+ static constexpr bool value{false};
+ };
+ template <int K> struct is_complex_constructor<ComplexConstructor<K>> {
+ static constexpr bool value{true};
+ };
+ template <typename T>
+ static constexpr bool is_complex_constructor_v{
+ is_complex_constructor<T>::value};
+};
+} // namespace operation
+
+std::optional<Expr<SomeType>> GetConvertInput(const Expr<SomeType> &x) {
+ // This returns Expr<SomeType>{x} when x is a designator/functionref/constant.
+ return operation::ConvertCollector{}(x).first;
+}
+
+bool IsSameOrConvertOf(const Expr<SomeType> &expr, const Expr<SomeType> &x) {
+ // Check if expr is same as x, or a sequence of Convert operations on x.
+ if (expr == x) {
+ return true;
+ } else if (auto maybe{GetConvertInput(expr)}) {
+ return *maybe == x;
+ } else {
+ return false;
+ }
+}
} // namespace Fortran::evaluate
namespace Fortran::semantics {
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 69e9c53baa740..3ef3330cba2d6 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -654,7 +654,7 @@ void genAtomicCapture(Fortran::lower::AbstractConverter &converter,
mlir::Block &block = atomicCaptureOp->getRegion(0).back();
firOpBuilder.setInsertionPointToStart(&block);
if (Fortran::parser::CheckForSingleVariableOnRHS(stmt1)) {
- if (Fortran::semantics::CheckForSymbolMatch(
+ if (Fortran::evaluate::CheckForSymbolMatch(
Fortran::semantics::GetExpr(stmt2Var),
Fortran::semantics::GetExpr(stmt2Expr))) {
// Atomic capture construct is of the form [capture-stmt, update-stmt]
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 82673f0948a5b..0acfd5b0a2534 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -2840,11 +2840,12 @@ genAtomicUpdate(lower::AbstractConverter &converter, mlir::Location loc,
mlir::Type atomType = fir::unwrapRefType(atomAddr.getType());
// This must ex...
[truncated]
|
@llvm/pr-subscribers-flang-fir-hlfir Author: Peter Klausler (klausler) ChangesSome new code was added to flang/Semantics that only depends on facilities in flang/Evaluate. Move it into Evaluate and clean up some minor stylistic problems. Patch is 32.05 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/144551.diff 7 Files Affected:
diff --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h
index 1959d5f3a5899..e04621f71f9a7 100644
--- a/flang/include/flang/Evaluate/tools.h
+++ b/flang/include/flang/Evaluate/tools.h
@@ -1389,6 +1389,154 @@ inline bool HasCUDAImplicitTransfer(const Expr<SomeType> &expr) {
return (hasConstant || (hostSymbols > 0)) && deviceSymbols > 0;
}
+// Checks whether the symbol on the LHS is present in the RHS expression.
+bool CheckForSymbolMatch(const Expr<SomeType> *lhs, const Expr<SomeType> *rhs);
+
+namespace operation {
+
+enum class Operator {
+ Unknown,
+ Add,
+ And,
+ Associated,
+ Call,
+ Constant,
+ Convert,
+ Div,
+ Eq,
+ Eqv,
+ False,
+ Ge,
+ Gt,
+ Identity,
+ Intrinsic,
+ Le,
+ Lt,
+ Max,
+ Min,
+ Mul,
+ Ne,
+ Neqv,
+ Not,
+ Or,
+ Pow,
+ Resize, // Convert within the same TypeCategory
+ Sub,
+ True,
+};
+
+std::string ToString(Operator op);
+
+template <typename... Ts, int Kind>
+Operator OperationCode(
+ const evaluate::Operation<evaluate::LogicalOperation<Kind>, Ts...> &op) {
+ switch (op.derived().logicalOperator) {
+ case common::LogicalOperator::And:
+ return Operator::And;
+ case common::LogicalOperator::Or:
+ return Operator::Or;
+ case common::LogicalOperator::Eqv:
+ return Operator::Eqv;
+ case common::LogicalOperator::Neqv:
+ return Operator::Neqv;
+ case common::LogicalOperator::Not:
+ return Operator::Not;
+ }
+ return Operator::Unknown;
+}
+
+template <typename T, typename... Ts>
+Operator OperationCode(
+ const evaluate::Operation<evaluate::Relational<T>, Ts...> &op) {
+ switch (op.derived().opr) {
+ case common::RelationalOperator::LT:
+ return Operator::Lt;
+ case common::RelationalOperator::LE:
+ return Operator::Le;
+ case common::RelationalOperator::EQ:
+ return Operator::Eq;
+ case common::RelationalOperator::NE:
+ return Operator::Ne;
+ case common::RelationalOperator::GE:
+ return Operator::Ge;
+ case common::RelationalOperator::GT:
+ return Operator::Gt;
+ }
+ return Operator::Unknown;
+}
+
+template <typename T, typename... Ts>
+Operator OperationCode(const evaluate::Operation<evaluate::Add<T>, Ts...> &op) {
+ return Operator::Add;
+}
+
+template <typename T, typename... Ts>
+Operator OperationCode(
+ const evaluate::Operation<evaluate::Subtract<T>, Ts...> &op) {
+ return Operator::Sub;
+}
+
+template <typename T, typename... Ts>
+Operator OperationCode(
+ const evaluate::Operation<evaluate::Multiply<T>, Ts...> &op) {
+ return Operator::Mul;
+}
+
+template <typename T, typename... Ts>
+Operator OperationCode(
+ const evaluate::Operation<evaluate::Divide<T>, Ts...> &op) {
+ return Operator::Div;
+}
+
+template <typename T, typename... Ts>
+Operator OperationCode(
+ const evaluate::Operation<evaluate::Power<T>, Ts...> &op) {
+ return Operator::Pow;
+}
+
+template <typename T, typename... Ts>
+Operator OperationCode(
+ const evaluate::Operation<evaluate::RealToIntPower<T>, Ts...> &op) {
+ return Operator::Pow;
+}
+
+template <typename T, common::TypeCategory C, typename... Ts>
+Operator OperationCode(
+ const evaluate::Operation<evaluate::Convert<T, C>, Ts...> &op) {
+ if constexpr (C == T::category) {
+ return Operator::Resize;
+ } else {
+ return Operator::Convert;
+ }
+}
+
+template <typename T> Operator OperationCode(const evaluate::Constant<T> &x) {
+ return Operator::Constant;
+}
+
+template <typename T> Operator OperationCode(const T &) {
+ return Operator::Unknown;
+}
+
+Operator OperationCode(const evaluate::ProcedureDesignator &proc);
+
+} // namespace operation
+
+// Return information about the top-level operation (ignoring parentheses):
+// the operation code and the list of arguments.
+std::pair<operation::Operator, std::vector<Expr<SomeType>>>
+GetTopLevelOperation(const Expr<SomeType> &expr);
+
+// Check if expr is same as x, or a sequence of Convert operations on x.
+bool IsSameOrConvertOf(const Expr<SomeType> &expr, const Expr<SomeType> &x);
+
+// Strip away any top-level Convert operations (if any exist) and return
+// the input value. A ComplexConstructor(x, 0) is also considered as a
+// convert operation.
+// If the input is not Operation, Designator, FunctionRef or Constant,
+// it returns std::nullopt.
+std::optional<Expr<SomeType>> GetConvertInput(const Expr<SomeType> &x);
+
} // namespace Fortran::evaluate
namespace Fortran::semantics {
diff --git a/flang/include/flang/Semantics/tools.h b/flang/include/flang/Semantics/tools.h
index 69375a83dec25..f3cfa9b99fb4d 100644
--- a/flang/include/flang/Semantics/tools.h
+++ b/flang/include/flang/Semantics/tools.h
@@ -756,154 +756,5 @@ std::string GetCommonBlockObjectName(const Symbol &, bool underscoring);
// Check for ambiguous USE associations
bool HadUseError(SemanticsContext &, SourceName at, const Symbol *);
-// Checks whether the symbol on the LHS is present in the RHS expression.
-bool CheckForSymbolMatch(const SomeExpr *lhs, const SomeExpr *rhs);
-
-namespace operation {
-
-enum class Operator {
- Unknown,
- Add,
- And,
- Associated,
- Call,
- Constant,
- Convert,
- Div,
- Eq,
- Eqv,
- False,
- Ge,
- Gt,
- Identity,
- Intrinsic,
- Le,
- Lt,
- Max,
- Min,
- Mul,
- Ne,
- Neqv,
- Not,
- Or,
- Pow,
- Resize, // Convert within the same TypeCategory
- Sub,
- True,
-};
-
-std::string ToString(Operator op);
-
-template <typename... Ts, int Kind>
-Operator OperationCode(
- const evaluate::Operation<evaluate::LogicalOperation<Kind>, Ts...> &op) {
- switch (op.derived().logicalOperator) {
- case common::LogicalOperator::And:
- return Operator::And;
- case common::LogicalOperator::Or:
- return Operator::Or;
- case common::LogicalOperator::Eqv:
- return Operator::Eqv;
- case common::LogicalOperator::Neqv:
- return Operator::Neqv;
- case common::LogicalOperator::Not:
- return Operator::Not;
- }
- return Operator::Unknown;
-}
-
-template <typename T, typename... Ts>
-Operator OperationCode(
- const evaluate::Operation<evaluate::Relational<T>, Ts...> &op) {
- switch (op.derived().opr) {
- case common::RelationalOperator::LT:
- return Operator::Lt;
- case common::RelationalOperator::LE:
- return Operator::Le;
- case common::RelationalOperator::EQ:
- return Operator::Eq;
- case common::RelationalOperator::NE:
- return Operator::Ne;
- case common::RelationalOperator::GE:
- return Operator::Ge;
- case common::RelationalOperator::GT:
- return Operator::Gt;
- }
- return Operator::Unknown;
-}
-
-template <typename T, typename... Ts>
-Operator OperationCode(const evaluate::Operation<evaluate::Add<T>, Ts...> &op) {
- return Operator::Add;
-}
-
-template <typename T, typename... Ts>
-Operator OperationCode(
- const evaluate::Operation<evaluate::Subtract<T>, Ts...> &op) {
- return Operator::Sub;
-}
-
-template <typename T, typename... Ts>
-Operator OperationCode(
- const evaluate::Operation<evaluate::Multiply<T>, Ts...> &op) {
- return Operator::Mul;
-}
-
-template <typename T, typename... Ts>
-Operator OperationCode(
- const evaluate::Operation<evaluate::Divide<T>, Ts...> &op) {
- return Operator::Div;
-}
-
-template <typename T, typename... Ts>
-Operator OperationCode(
- const evaluate::Operation<evaluate::Power<T>, Ts...> &op) {
- return Operator::Pow;
-}
-
-template <typename T, typename... Ts>
-Operator OperationCode(
- const evaluate::Operation<evaluate::RealToIntPower<T>, Ts...> &op) {
- return Operator::Pow;
-}
-
-template <typename T, common::TypeCategory C, typename... Ts>
-Operator OperationCode(
- const evaluate::Operation<evaluate::Convert<T, C>, Ts...> &op) {
- if constexpr (C == T::category) {
- return Operator::Resize;
- } else {
- return Operator::Convert;
- }
-}
-
-template <typename T> //
-Operator OperationCode(const evaluate::Constant<T> &x) {
- return Operator::Constant;
-}
-
-template <typename T> //
-Operator OperationCode(const T &) {
- return Operator::Unknown;
-}
-
-Operator OperationCode(const evaluate::ProcedureDesignator &proc);
-
-} // namespace operation
-
-/// Return information about the top-level operation (ignoring parentheses):
-/// the operation code and the list of arguments.
-std::pair<operation::Operator, std::vector<SomeExpr>> GetTopLevelOperation(
- const SomeExpr &expr);
-
-/// Check if expr is same as x, or a sequence of Convert operations on x.
-bool IsSameOrConvertOf(const SomeExpr &expr, const SomeExpr &x);
-
-/// Strip away any top-level Convert operations (if any exist) and return
-/// the input value. A ComplexConstructor(x, 0) is also considered as a
-/// convert operation.
-/// If the input is not Operation, Designator, FunctionRef or Constant,
-/// it returns std::nullopt.
-MaybeExpr GetConvertInput(const SomeExpr &x);
} // namespace Fortran::semantics
#endif // FORTRAN_SEMANTICS_TOOLS_H_
diff --git a/flang/lib/Evaluate/tools.cpp b/flang/lib/Evaluate/tools.cpp
index 222c32a9c332e..68838564f87ba 100644
--- a/flang/lib/Evaluate/tools.cpp
+++ b/flang/lib/Evaluate/tools.cpp
@@ -13,6 +13,7 @@
#include "flang/Evaluate/traverse.h"
#include "flang/Parser/message.h"
#include "flang/Semantics/tools.h"
+#include "llvm/ADT/StringSwitch.h"
#include <algorithm>
#include <variant>
@@ -1595,6 +1596,316 @@ bool CheckForCoindexedObject(parser::ContextualMessages &messages,
}
}
+bool CheckForSymbolMatch(const Expr<SomeType> *lhs, const Expr<SomeType> *rhs) {
+ if (lhs && rhs) {
+ if (SymbolVector lhsSymbols{GetSymbolVector(*lhs)}; !lhsSymbols.empty()) {
+ const Symbol &first{*lhsSymbols.front()};
+ for (const Symbol &symbol : GetSymbolVector(*rhs)) {
+ if (first == symbol) {
+ return true;
+ }
+ }
+ }
+ }
+ return false;
+}
+
+namespace operation {
+template <typename T> Expr<SomeType> AsSomeExpr(const T &x) {
+ return AsGenericExpr(common::Clone(x));
+}
+
+template <bool IgnoreResizingConverts>
+struct ArgumentExtractor
+ : public Traverse<ArgumentExtractor<IgnoreResizingConverts>,
+ std::pair<operation::Operator, std::vector<Expr<SomeType>>>, false> {
+ using Arguments = std::vector<Expr<SomeType>>;
+ using Result = std::pair<operation::Operator, Arguments>;
+ using Base =
+ Traverse<ArgumentExtractor<IgnoreResizingConverts>, Result, false>;
+ static constexpr auto IgnoreResizes{IgnoreResizingConverts};
+ static constexpr auto Logical{common::TypeCategory::Logical};
+ ArgumentExtractor() : Base(*this) {}
+
+ Result Default() const { return {}; }
+
+ using Base::operator();
+
+ template <int Kind>
+ Result operator()(const Constant<Type<Logical, Kind>> &x) const {
+ if (const auto &val{x.GetScalarValue()}) {
+ return val->IsTrue()
+ ? std::make_pair(operation::Operator::True, Arguments{})
+ : std::make_pair(operation::Operator::False, Arguments{});
+ }
+ return Default();
+ }
+
+ template <typename R> Result operator()(const FunctionRef<R> &x) const {
+ Result result{operation::OperationCode(x.proc()), {}};
+ for (size_t i{0}, e{x.arguments().size()}; i != e; ++i) {
+ if (auto *e{x.UnwrapArgExpr(i)}) {
+ result.second.push_back(*e);
+ }
+ }
+ return result;
+ }
+
+ template <typename D, typename R, typename... Os>
+ Result operator()(const Operation<D, R, Os...> &x) const {
+ if constexpr (std::is_same_v<D, Parentheses<R>>) {
+ // Ignore top-level parentheses.
+ return (*this)(x.template operand<0>());
+ }
+ if constexpr (IgnoreResizes && std::is_same_v<D, Convert<R, R::category>>) {
+ // Ignore conversions within the same category.
+ // Atomic operations on int(kind=1) may be implicitly widened
+ // to int(kind=4) for example.
+ return (*this)(x.template operand<0>());
+ } else {
+ return std::make_pair(operation::OperationCode(x),
+ OperationArgs(x, std::index_sequence_for<Os...>{}));
+ }
+ }
+
+ template <typename T> Result operator()(const Designator<T> &x) const {
+ return {operation::Operator::Identity, {AsSomeExpr(x)}};
+ }
+
+ template <typename T> Result operator()(const Constant<T> &x) const {
+ return {operation::Operator::Identity, {AsSomeExpr(x)}};
+ }
+
+ template <typename... Rs>
+ Result Combine(Result &&result, Rs &&...results) const {
+ // There shouldn't be any combining needed, since we're stopping the
+ // traversal at the top-level operation, but implement one that picks
+ // the first non-empty result.
+ if constexpr (sizeof...(Rs) == 0) {
+ return std::move(result);
+ } else {
+ if (!result.second.empty()) {
+ return std::move(result);
+ } else {
+ return Combine(std::move(results)...);
+ }
+ }
+ }
+
+private:
+ template <typename D, typename R, typename... Os, size_t... Is>
+ Arguments OperationArgs(
+ const Operation<D, R, Os...> &x, std::index_sequence<Is...>) const {
+ return Arguments{Expr<SomeType>(x.template operand<Is>())...};
+ }
+};
+} // namespace operation
+
+std::string operation::ToString(operation::Operator op) {
+ switch (op) {
+ case Operator::Unknown:
+ return "??";
+ case Operator::Add:
+ return "+";
+ case Operator::And:
+ return "AND";
+ case Operator::Associated:
+ return "ASSOCIATED";
+ case Operator::Call:
+ return "function-call";
+ case Operator::Constant:
+ return "constant";
+ case Operator::Convert:
+ return "type-conversion";
+ case Operator::Div:
+ return "/";
+ case Operator::Eq:
+ return "==";
+ case Operator::Eqv:
+ return "EQV";
+ case Operator::False:
+ return ".FALSE.";
+ case Operator::Ge:
+ return ">=";
+ case Operator::Gt:
+ return ">";
+ case Operator::Identity:
+ return "identity";
+ case Operator::Intrinsic:
+ return "intrinsic";
+ case Operator::Le:
+ return "<=";
+ case Operator::Lt:
+ return "<";
+ case Operator::Max:
+ return "MAX";
+ case Operator::Min:
+ return "MIN";
+ case Operator::Mul:
+ return "*";
+ case Operator::Ne:
+ return "/=";
+ case Operator::Neqv:
+ return "NEQV/EOR";
+ case Operator::Not:
+ return "NOT";
+ case Operator::Or:
+ return "OR";
+ case Operator::Pow:
+ return "**";
+ case Operator::Resize:
+ return "resize";
+ case Operator::Sub:
+ return "-";
+ case Operator::True:
+ return ".TRUE.";
+ }
+ llvm_unreachable("Unhandler operator");
+}
+
+operation::Operator operation::OperationCode(const ProcedureDesignator &proc) {
+ Operator code{llvm::StringSwitch<Operator>(proc.GetName())
+ .Case("associated", Operator::Associated)
+ .Case("min", Operator::Min)
+ .Case("max", Operator::Max)
+ .Case("iand", Operator::And)
+ .Case("ior", Operator::Or)
+ .Case("ieor", Operator::Neqv)
+ .Default(Operator::Call)};
+ if (code == Operator::Call && proc.GetSpecificIntrinsic()) {
+ return Operator::Intrinsic;
+ }
+ return code;
+}
+
+std::pair<operation::Operator, std::vector<Expr<SomeType>>>
+GetTopLevelOperation(const Expr<SomeType> &expr) {
+ return operation::ArgumentExtractor<true>{}(expr);
+}
+
+namespace operation {
+struct ConvertCollector
+ : public Traverse<ConvertCollector,
+ std::pair<std::optional<Expr<SomeType>>, std::vector<DynamicType>>,
+ false> {
+ using Result =
+ std::pair<std::optional<Expr<SomeType>>, std::vector<DynamicType>>;
+ using Base = Traverse<ConvertCollector, Result, false>;
+ ConvertCollector() : Base(*this) {}
+
+ Result Default() const { return {}; }
+
+ using Base::operator();
+
+ template <typename T> Result operator()(const Designator<T> &x) const {
+ return {AsSomeExpr(x), {}};
+ }
+
+ template <typename T> Result operator()(const FunctionRef<T> &x) const {
+ return {AsSomeExpr(x), {}};
+ }
+
+ template <typename T> Result operator()(const Constant<T> &x) const {
+ return {AsSomeExpr(x), {}};
+ }
+
+ template <typename D, typename R, typename... Os>
+ Result operator()(const Operation<D, R, Os...> &x) const {
+ if constexpr (std::is_same_v<D, Parentheses<R>>) {
+ // Ignore parentheses.
+ return (*this)(x.template operand<0>());
+ } else if constexpr (is_convert_v<D>) {
+ // Convert should always have a typed result, so it should be safe to
+ // dereference x.GetType().
+ return Combine(
+ {std::nullopt, {*x.GetType()}}, (*this)(x.template operand<0>()));
+ } else if constexpr (is_complex_constructor_v<D>) {
+ // This is a conversion iff the imaginary operand is 0.
+ if (IsZero(x.template operand<1>())) {
+ return Combine(
+ {std::nullopt, {*x.GetType()}}, (*this)(x.template operand<0>()));
+ } else {
+ return {AsSomeExpr(x.derived()), {}};
+ }
+ } else {
+ return {AsSomeExpr(x.derived()), {}};
+ }
+ }
+
+ template <typename... Rs>
+ Result Combine(Result &&result, Rs &&...results) const {
+ Result v(std::move(result));
+ auto setValue{[](std::optional<Expr<SomeType>> &x,
+ std::optional<Expr<SomeType>> &&y) {
+ assert((!x.has_value() || !y.has_value()) && "Multiple designators");
+ if (!x.has_value()) {
+ x = std::move(y);
+ }
+ }};
+ auto moveAppend{[](auto &accum, auto &&other) {
+ for (auto &&s : other) {
+ accum.push_back(std::move(s));
+ }
+ }};
+ (setValue(v.first, std::move(results).first), ...);
+ (moveAppend(v.second, std::move(results).second), ...);
+ return v;
+ }
+
+private:
+ template <typename A> static bool IsZero(const A &x) { return false; }
+ template <typename T> static bool IsZero(const Expr<T> &x) {
+ return common::visit([](auto &&s) { return IsZero(s); }, x.u);
+ }
+ template <typename T> static bool IsZero(const Constant<T> &x) {
+ if (auto &&maybeScalar{x.GetScalarValue()}) {
+ return maybeScalar->IsZero();
+ } else {
+ return false;
+ }
+ }
+
+ template <typename T> struct is_convert {
+ static constexpr bool value{false};
+ };
+ template <typename T, common::TypeCategory C>
+ struct is_convert<Convert<T, C>> {
+ static constexpr bool value{true};
+ };
+ template <int K> struct is_convert<ComplexComponent<K>> {
+ // Conversion from complex to real.
+ static constexpr bool value{true};
+ };
+ template <typename T>
+ static constexpr bool is_convert_v{is_convert<T>::value};
+
+ template <typename T> struct is_complex_constructor {
+ static constexpr bool value{false};
+ };
+ template <int K> struct is_complex_constructor<ComplexConstructor<K>> {
+ static constexpr bool value{true};
+ };
+ template <typename T>
+ static constexpr bool is_complex_constructor_v{
+ is_complex_constructor<T>::value};
+};
+} // namespace operation
+
+std::optional<Expr<SomeType>> GetConvertInput(const Expr<SomeType> &x) {
+ // This returns Expr<SomeType>{x} when x is a designator/functionref/constant.
+ return operation::ConvertCollector{}(x).first;
+}
+
+bool IsSameOrConvertOf(const Expr<SomeType> &expr, const Expr<SomeType> &x) {
+ // Check if expr is same as x, or a sequence of Convert operations on x.
+ if (expr == x) {
+ return true;
+ } else if (auto maybe{GetConvertInput(expr)}) {
+ return *maybe == x;
+ } else {
+ return false;
+ }
+}
} // namespace Fortran::evaluate
namespace Fortran::semantics {
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 69e9c53baa740..3ef3330cba2d6 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -654,7 +654,7 @@ void genAtomicCapture(Fortran::lower::AbstractConverter &converter,
mlir::Block &block = atomicCaptureOp->getRegion(0).back();
firOpBuilder.setInsertionPointToStart(&block);
if (Fortran::parser::CheckForSingleVariableOnRHS(stmt1)) {
- if (Fortran::semantics::CheckForSymbolMatch(
+ if (Fortran::evaluate::CheckForSymbolMatch(
Fortran::semantics::GetExpr(stmt2Var),
Fortran::semantics::GetExpr(stmt2Expr))) {
// Atomic capture construct is of the form [capture-stmt, update-stmt]
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 82673f0948a5b..0acfd5b0a2534 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -2840,11 +2840,12 @@ genAtomicUpdate(lower::AbstractConverter &converter, mlir::Location loc,
mlir::Type atomType = fir::unwrapRefType(atomAddr.getType());
// This must ex...
[truncated]
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thanks!
Some new code was added to flang/Semantics that only depends on facilities in flang/Evaluate. Move it into Evaluate and clean up some minor stylistic problems.