Skip to content

Commit b08b252

Browse files
kparzysztblah
andauthored
[flang][OpenMP] Semantic checks for DOACROSS clause (llvm#115397)
Keep track of loop constructs and OpenMP loop constructs that have been entered. Use the information to validate the variables in the SINK loop iteration vector. --------- Co-authored-by: Tom Eccles <tom.eccles@arm.com>
1 parent 3338186 commit b08b252

File tree

8 files changed

+237
-39
lines changed

8 files changed

+237
-39
lines changed

flang/lib/Lower/OpenMP/Clauses.cpp

+17-11
Original file line numberDiff line numberDiff line change
@@ -574,20 +574,17 @@ Defaultmap make(const parser::OmpClause::Defaultmap &inp,
574574
/*VariableCategory=*/maybeApply(convert2, t1)}};
575575
}
576576

577-
Depend make(const parser::OmpClause::Depend &inp,
578-
semantics::SemanticsContext &semaCtx) {
579-
// inp.v -> parser::OmpDependClause
580-
using wrapped = parser::OmpDependClause;
581-
using Variant = decltype(Depend::u);
577+
Doacross makeDoacross(const parser::OmpDoacross &doa,
578+
semantics::SemanticsContext &semaCtx) {
582579
// Iteration is the equivalent of parser::OmpIteration
583580
using Iteration = Doacross::Vector::value_type; // LoopIterationT
584581

585-
auto visitSource = [&](const parser::OmpDoacross::Source &) -> Variant {
582+
auto visitSource = [&](const parser::OmpDoacross::Source &) {
586583
return Doacross{{/*DependenceType=*/Doacross::DependenceType::Source,
587584
/*Vector=*/{}}};
588585
};
589586

590-
auto visitSink = [&](const parser::OmpDoacross::Sink &s) -> Variant {
587+
auto visitSink = [&](const parser::OmpDoacross::Sink &s) {
591588
using IterOffset = parser::OmpIterationOffset;
592589
auto convert2 = [&](const parser::OmpIteration &v) {
593590
auto &t0 = std::get<parser::Name>(v.t);
@@ -605,6 +602,15 @@ Depend make(const parser::OmpClause::Depend &inp,
605602
/*Vector=*/makeList(s.v.v, convert2)}};
606603
};
607604

605+
return common::visit(common::visitors{visitSink, visitSource}, doa.u);
606+
}
607+
608+
Depend make(const parser::OmpClause::Depend &inp,
609+
semantics::SemanticsContext &semaCtx) {
610+
// inp.v -> parser::OmpDependClause
611+
using wrapped = parser::OmpDependClause;
612+
using Variant = decltype(Depend::u);
613+
608614
auto visitTaskDep = [&](const wrapped::TaskDep &s) -> Variant {
609615
auto &t0 = std::get<std::optional<parser::OmpIteratorModifier>>(s.t);
610616
auto &t1 = std::get<parser::OmpTaskDependenceType>(s.t);
@@ -617,11 +623,11 @@ Depend make(const parser::OmpClause::Depend &inp,
617623
/*LocatorList=*/makeObjects(t2, semaCtx)}};
618624
};
619625

620-
return Depend{Fortran::common::visit( //
626+
return Depend{common::visit( //
621627
common::visitors{
622628
// Doacross
623629
[&](const parser::OmpDoacross &s) -> Variant {
624-
return common::visit(common::visitors{visitSink, visitSource}, s.u);
630+
return makeDoacross(s, semaCtx);
625631
},
626632
// Depend::TaskDep
627633
visitTaskDep,
@@ -692,8 +698,8 @@ DistSchedule make(const parser::OmpClause::DistSchedule &inp,
692698

693699
Doacross make(const parser::OmpClause::Doacross &inp,
694700
semantics::SemanticsContext &semaCtx) {
695-
// inp -> empty
696-
llvm_unreachable("Empty: doacross");
701+
// inp.v -> OmpDoacrossClause
702+
return makeDoacross(inp.v.v, semaCtx);
697703
}
698704

699705
// DynamicAllocators: empty

flang/lib/Semantics/check-omp-structure.cpp

+137-16
Original file line numberDiff line numberDiff line change
@@ -575,6 +575,7 @@ void OmpStructureChecker::Leave(const parser::OpenMPConstruct &) {
575575
}
576576

577577
void OmpStructureChecker::Enter(const parser::OpenMPLoopConstruct &x) {
578+
loopStack_.push_back(&x);
578579
const auto &beginLoopDir{std::get<parser::OmpBeginLoopDirective>(x.t)};
579580
const auto &beginDir{std::get<parser::OmpLoopDirective>(beginLoopDir.t)};
580581

@@ -968,11 +969,19 @@ void OmpStructureChecker::CheckDistLinear(
968969
}
969970
}
970971

971-
void OmpStructureChecker::Leave(const parser::OpenMPLoopConstruct &) {
972+
void OmpStructureChecker::Leave(const parser::OpenMPLoopConstruct &x) {
972973
if (llvm::omp::allSimdSet.test(GetContext().directive)) {
973974
ExitDirectiveNest(SIMDNest);
974975
}
975976
dirContext_.pop_back();
977+
978+
assert(!loopStack_.empty() && "Expecting non-empty loop stack");
979+
const LoopConstruct &top{loopStack_.back()};
980+
#ifndef NDEBUG
981+
auto *loopc{std::get_if<const parser::OpenMPLoopConstruct *>(&top)};
982+
assert(loopc != nullptr && *loopc == &x && "Mismatched loop constructs");
983+
#endif
984+
loopStack_.pop_back();
976985
}
977986

978987
void OmpStructureChecker::Enter(const parser::OmpEndLoopDirective &x) {
@@ -1103,8 +1112,7 @@ void OmpStructureChecker::Leave(const parser::OpenMPBlockConstruct &) {
11031112
void OmpStructureChecker::ChecksOnOrderedAsBlock() {
11041113
if (FindClause(llvm::omp::Clause::OMPC_depend)) {
11051114
context_.Say(GetContext().clauseSource,
1106-
"DEPEND(*) clauses are not allowed when ORDERED construct is a block"
1107-
" construct with an ORDERED region"_err_en_US);
1115+
"DEPEND clauses are not allowed when ORDERED construct is a block construct with an ORDERED region"_err_en_US);
11081116
return;
11091117
}
11101118

@@ -1654,15 +1662,14 @@ void OmpStructureChecker::ChecksOnOrderedAsStandalone() {
16541662
if (FindClause(llvm::omp::Clause::OMPC_threads) ||
16551663
FindClause(llvm::omp::Clause::OMPC_simd)) {
16561664
context_.Say(GetContext().clauseSource,
1657-
"THREADS, SIMD clauses are not allowed when ORDERED construct is a "
1658-
"standalone construct with no ORDERED region"_err_en_US);
1665+
"THREADS and SIMD clauses are not allowed when ORDERED construct is a standalone construct with no ORDERED region"_err_en_US);
16591666
}
16601667

16611668
int dependSinkCount{0}, dependSourceCount{0};
16621669
bool exclusiveShown{false}, duplicateSourceShown{false};
16631670

1664-
auto visitDoacross = [&](const parser::OmpDoacross &doa,
1665-
const parser::CharBlock &src) {
1671+
auto visitDoacross{[&](const parser::OmpDoacross &doa,
1672+
const parser::CharBlock &src) {
16661673
common::visit(
16671674
common::visitors{
16681675
[&](const parser::OmpDoacross::Source &) { dependSourceCount++; },
@@ -1678,10 +1685,11 @@ void OmpStructureChecker::ChecksOnOrderedAsStandalone() {
16781685
context_.Say(src,
16791686
"At most one SOURCE dependence type can appear on the ORDERED directive"_err_en_US);
16801687
}
1681-
};
1688+
}};
16821689

1683-
auto clauseAll = FindClauses(llvm::omp::Clause::OMPC_depend);
1684-
for (auto itr = clauseAll.first; itr != clauseAll.second; ++itr) {
1690+
// Visit the DEPEND and DOACROSS clauses.
1691+
auto depClauses{FindClauses(llvm::omp::Clause::OMPC_depend)};
1692+
for (auto itr{depClauses.first}; itr != depClauses.second; ++itr) {
16851693
const auto &dependClause{
16861694
std::get<parser::OmpClause::Depend>(itr->second->u)};
16871695
if (auto *doAcross{std::get_if<parser::OmpDoacross>(&dependClause.v.u)}) {
@@ -1691,6 +1699,11 @@ void OmpStructureChecker::ChecksOnOrderedAsStandalone() {
16911699
"Only SINK or SOURCE dependence types are allowed when ORDERED construct is a standalone construct with no ORDERED region"_err_en_US);
16921700
}
16931701
}
1702+
auto doaClauses{FindClauses(llvm::omp::Clause::OMPC_doacross)};
1703+
for (auto itr{doaClauses.first}; itr != doaClauses.second; ++itr) {
1704+
auto &doaClause{std::get<parser::OmpClause::Doacross>(itr->second->u)};
1705+
visitDoacross(doaClause.v.v, itr->second->source);
1706+
}
16941707

16951708
bool isNestedInDoOrderedWithPara{false};
16961709
if (CurrentDirectiveIsNested() &&
@@ -1718,23 +1731,28 @@ void OmpStructureChecker::ChecksOnOrderedAsStandalone() {
17181731

17191732
void OmpStructureChecker::CheckOrderedDependClause(
17201733
std::optional<int64_t> orderedValue) {
1721-
auto visitDoacross = [&](const parser::OmpDoacross &doa,
1722-
const parser::CharBlock &src) {
1734+
auto visitDoacross{[&](const parser::OmpDoacross &doa,
1735+
const parser::CharBlock &src) {
17231736
if (auto *sinkVector{std::get_if<parser::OmpDoacross::Sink>(&doa.u)}) {
17241737
int64_t numVar = sinkVector->v.v.size();
17251738
if (orderedValue != numVar) {
17261739
context_.Say(src,
17271740
"The number of variables in the SINK iteration vector does not match the parameter specified in ORDERED clause"_err_en_US);
17281741
}
17291742
}
1730-
};
1731-
auto clauseAll{FindClauses(llvm::omp::Clause::OMPC_depend)};
1732-
for (auto itr = clauseAll.first; itr != clauseAll.second; ++itr) {
1743+
}};
1744+
auto depClauses{FindClauses(llvm::omp::Clause::OMPC_depend)};
1745+
for (auto itr{depClauses.first}; itr != depClauses.second; ++itr) {
17331746
auto &dependClause{std::get<parser::OmpClause::Depend>(itr->second->u)};
17341747
if (auto *doAcross{std::get_if<parser::OmpDoacross>(&dependClause.v.u)}) {
17351748
visitDoacross(*doAcross, itr->second->source);
17361749
}
17371750
}
1751+
auto doaClauses = FindClauses(llvm::omp::Clause::OMPC_doacross);
1752+
for (auto itr{doaClauses.first}; itr != doaClauses.second; ++itr) {
1753+
auto &doaClause{std::get<parser::OmpClause::Doacross>(itr->second->u)};
1754+
visitDoacross(doaClause.v.v, itr->second->source);
1755+
}
17381756
}
17391757

17401758
void OmpStructureChecker::CheckTargetUpdate() {
@@ -2712,7 +2730,6 @@ CHECK_SIMPLE_CLAUSE(Bind, OMPC_bind)
27122730
CHECK_SIMPLE_CLAUSE(Align, OMPC_align)
27132731
CHECK_SIMPLE_CLAUSE(Compare, OMPC_compare)
27142732
CHECK_SIMPLE_CLAUSE(CancellationConstructType, OMPC_cancellation_construct_type)
2715-
CHECK_SIMPLE_CLAUSE(Doacross, OMPC_doacross)
27162733
CHECK_SIMPLE_CLAUSE(OmpxAttribute, OMPC_ompx_attribute)
27172734
CHECK_SIMPLE_CLAUSE(OmpxBare, OMPC_ompx_bare)
27182735
CHECK_SIMPLE_CLAUSE(Fail, OMPC_fail)
@@ -3493,6 +3510,7 @@ void OmpStructureChecker::Enter(const parser::OmpClause::Depend &x) {
34933510
"Unexpected alternative in update clause");
34943511

34953512
if (doaDep) {
3513+
CheckDoacross(*doaDep);
34963514
CheckDependenceType(doaDep->GetDepType());
34973515
} else {
34983516
CheckTaskDependenceType(taskDep->GetTaskDepType());
@@ -3572,6 +3590,93 @@ void OmpStructureChecker::Enter(const parser::OmpClause::Depend &x) {
35723590
}
35733591
}
35743592

3593+
void OmpStructureChecker::Enter(const parser::OmpClause::Doacross &x) {
3594+
CheckAllowedClause(llvm::omp::Clause::OMPC_doacross);
3595+
CheckDoacross(x.v.v);
3596+
}
3597+
3598+
void OmpStructureChecker::CheckDoacross(const parser::OmpDoacross &doa) {
3599+
if (std::holds_alternative<parser::OmpDoacross::Source>(doa.u)) {
3600+
// Nothing to check here.
3601+
return;
3602+
}
3603+
3604+
// Process SINK dependence type. SINK may only appear in an ORDER construct,
3605+
// which references a prior ORDERED(n) clause on a DO or SIMD construct
3606+
// that marks the top of the loop nest.
3607+
3608+
auto &sink{std::get<parser::OmpDoacross::Sink>(doa.u)};
3609+
const std::list<parser::OmpIteration> &vec{sink.v.v};
3610+
3611+
// Check if the variables in the iteration vector are unique.
3612+
struct Less {
3613+
bool operator()(
3614+
const parser::OmpIteration *a, const parser::OmpIteration *b) const {
3615+
auto namea{std::get<parser::Name>(a->t)};
3616+
auto nameb{std::get<parser::Name>(b->t)};
3617+
assert(namea.symbol && nameb.symbol && "Unresolved symbols");
3618+
// The non-determinism of the "<" doesn't matter, we only care about
3619+
// equality, i.e. a == b <=> !(a < b) && !(b < a)
3620+
return reinterpret_cast<uintptr_t>(namea.symbol) <
3621+
reinterpret_cast<uintptr_t>(nameb.symbol);
3622+
}
3623+
};
3624+
if (auto *duplicate{FindDuplicateEntry<parser::OmpIteration, Less>(vec)}) {
3625+
auto name{std::get<parser::Name>(duplicate->t)};
3626+
context_.Say(name.source,
3627+
"Duplicate variable '%s' in the iteration vector"_err_en_US,
3628+
name.ToString());
3629+
}
3630+
3631+
// Check if the variables in the iteration vector are induction variables.
3632+
// Ignore any mismatch between the size of the iteration vector and the
3633+
// number of DO constructs on the stack. This is checked elsewhere.
3634+
3635+
auto GetLoopDirective{[](const parser::OpenMPLoopConstruct &x) {
3636+
auto &begin{std::get<parser::OmpBeginLoopDirective>(x.t)};
3637+
return std::get<parser::OmpLoopDirective>(begin.t).v;
3638+
}};
3639+
auto GetLoopClauses{[](const parser::OpenMPLoopConstruct &x)
3640+
-> const std::list<parser::OmpClause> & {
3641+
auto &begin{std::get<parser::OmpBeginLoopDirective>(x.t)};
3642+
return std::get<parser::OmpClauseList>(begin.t).v;
3643+
}};
3644+
3645+
std::set<const Symbol *> inductionVars;
3646+
for (const LoopConstruct &loop : llvm::reverse(loopStack_)) {
3647+
if (auto *doc{std::get_if<const parser::DoConstruct *>(&loop)}) {
3648+
// Do-construct, collect the induction variable.
3649+
if (auto &control{(*doc)->GetLoopControl()}) {
3650+
if (auto *b{std::get_if<parser::LoopControl::Bounds>(&control->u)}) {
3651+
inductionVars.insert(b->name.thing.symbol);
3652+
}
3653+
}
3654+
} else {
3655+
// Omp-loop-construct, check if it's do/simd with an ORDERED clause.
3656+
auto *loopc{std::get_if<const parser::OpenMPLoopConstruct *>(&loop)};
3657+
assert(loopc && "Expecting OpenMPLoopConstruct");
3658+
llvm::omp::Directive loopDir{GetLoopDirective(**loopc)};
3659+
if (loopDir == llvm::omp::OMPD_do || loopDir == llvm::omp::OMPD_simd) {
3660+
auto IsOrdered{[](const parser::OmpClause &c) {
3661+
return c.Id() == llvm::omp::OMPC_ordered;
3662+
}};
3663+
// If it has ORDERED clause, stop the traversal.
3664+
if (llvm::any_of(GetLoopClauses(**loopc), IsOrdered)) {
3665+
break;
3666+
}
3667+
}
3668+
}
3669+
}
3670+
for (const parser::OmpIteration &iter : vec) {
3671+
auto &name{std::get<parser::Name>(iter.t)};
3672+
if (!inductionVars.count(name.symbol)) {
3673+
context_.Say(name.source,
3674+
"The iteration vector element '%s' is not an induction variable within the ORDERED loop nest"_err_en_US,
3675+
name.ToString());
3676+
}
3677+
}
3678+
}
3679+
35753680
void OmpStructureChecker::CheckCopyingPolymorphicAllocatable(
35763681
SymbolSourceMap &symbols, const llvm::omp::Clause clause) {
35773682
if (context_.ShouldWarn(common::UsageWarning::Portability)) {
@@ -4326,6 +4431,22 @@ void OmpStructureChecker::Enter(
43264431
CheckAllowedRequiresClause(llvm::omp::Clause::OMPC_unified_shared_memory);
43274432
}
43284433

4434+
void OmpStructureChecker::Enter(const parser::DoConstruct &x) {
4435+
Base::Enter(x);
4436+
loopStack_.push_back(&x);
4437+
}
4438+
4439+
void OmpStructureChecker::Leave(const parser::DoConstruct &x) {
4440+
assert(!loopStack_.empty() && "Expecting non-empty loop stack");
4441+
const LoopConstruct &top = loopStack_.back();
4442+
#ifndef NDEBUG
4443+
auto *doc{std::get_if<const parser::DoConstruct *>(&top)};
4444+
assert(doc != nullptr && *doc == &x && "Mismatched loop constructs");
4445+
#endif
4446+
loopStack_.pop_back();
4447+
Base::Leave(x);
4448+
}
4449+
43294450
void OmpStructureChecker::CheckAllowedRequiresClause(llvmOmpClause clause) {
43304451
CheckAllowedClause(clause);
43314452

flang/lib/Semantics/check-omp-structure.h

+19-6
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ class OmpStructureChecker
6060
: public DirectiveStructureChecker<llvm::omp::Directive, llvm::omp::Clause,
6161
parser::OmpClause, llvm::omp::Clause_enumSize> {
6262
public:
63+
using Base = DirectiveStructureChecker<llvm::omp::Directive,
64+
llvm::omp::Clause, parser::OmpClause, llvm::omp::Clause_enumSize>;
65+
6366
OmpStructureChecker(SemanticsContext &context)
6467
: DirectiveStructureChecker(context,
6568
#define GEN_FLANG_DIRECTIVE_CLAUSE_MAP
@@ -131,6 +134,9 @@ class OmpStructureChecker
131134
void Enter(const parser::OmpAtomicCapture &);
132135
void Leave(const parser::OmpAtomic &);
133136

137+
void Enter(const parser::DoConstruct &);
138+
void Leave(const parser::DoConstruct &);
139+
134140
#define GEN_FLANG_CLAUSE_CHECK_ENTER
135141
#include "llvm/Frontend/OpenMP/OMP.inc"
136142

@@ -157,13 +163,19 @@ class OmpStructureChecker
157163
const parser::OmpScheduleModifierType::ModType &);
158164
void CheckAllowedMapTypes(const parser::OmpMapClause::Type &,
159165
const std::list<parser::OmpMapClause::Type> &);
160-
template <typename T> const T *FindDuplicateEntry(const std::list<T> &);
161166
llvm::StringRef getClauseName(llvm::omp::Clause clause) override;
162167
llvm::StringRef getDirectiveName(llvm::omp::Directive directive) override;
163168

169+
template <typename T> struct DefaultLess {
170+
bool operator()(const T *a, const T *b) const { return *a < *b; }
171+
};
172+
template <typename T, typename Less = DefaultLess<T>>
173+
const T *FindDuplicateEntry(const std::list<T> &);
174+
164175
void CheckDependList(const parser::DataRef &);
165176
void CheckDependArraySection(
166177
const common::Indirection<parser::ArrayElement> &, const parser::Name &);
178+
void CheckDoacross(const parser::OmpDoacross &doa);
167179
bool IsDataRefTypeParamInquiry(const parser::DataRef *dataRef);
168180
void CheckIsVarPartOfAnotherVar(const parser::CharBlock &source,
169181
const parser::OmpObjectList &objList, llvm::StringRef clause = "");
@@ -255,20 +267,21 @@ class OmpStructureChecker
255267
int directiveNest_[LastType + 1] = {0};
256268

257269
SymbolSourceMap deferredNonVariables_;
270+
271+
using LoopConstruct = std::variant<const parser::DoConstruct *,
272+
const parser::OpenMPLoopConstruct *>;
273+
std::vector<LoopConstruct> loopStack_;
258274
};
259275

260-
template <typename T>
276+
template <typename T, typename Less>
261277
const T *OmpStructureChecker::FindDuplicateEntry(const std::list<T> &list) {
262278
// Add elements of the list to a set. If the insertion fails, return
263279
// the address of the failing element.
264280

265281
// The objects of type T may not be copyable, so add their addresses
266282
// to the set. The set will need to compare the actual objects, so
267283
// the custom comparator is provided.
268-
struct less {
269-
bool operator()(const T *a, const T *b) const { return *a < *b; }
270-
};
271-
std::set<const T *, less> uniq;
284+
std::set<const T *, Less> uniq;
272285

273286
for (const T &item : list) {
274287
if (!uniq.insert(&item).second) {

0 commit comments

Comments
 (0)