Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[flang][Lower] Emit exiting branches from within constructs #92455

Merged
merged 2 commits into from
May 21, 2024

Conversation

kparzysz
Copy link
Contributor

When lowering IfConstruct, CaseConstruct, and SelectTypeConstruct, emit branches that exit the construct in each block that is still unterminated after the FIR has been generated in it.

The same thing may be needed for SelectRankConstruct, once it's supported.

This eliminates the need for inserting branches in genFIR(Evaluation).

Follow-up to PR #91614.

When lowering IfConstruct, CaseConstruct, and SelectTypeConstruct,
emit branches that exit the construct in each block that is still
unterminated after the FIR has been generated in it.

The same thing may be needed for SelectRankConstruct, once it's
supported.

This eliminates the need for inserting branches in `genFIR(Evaluation)`.

Follow-up to PR llvm#91614.
@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir labels May 16, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented May 16, 2024

@llvm/pr-subscribers-flang-fir-hlfir

Author: Krzysztof Parzyszek (kparzysz)

Changes

When lowering IfConstruct, CaseConstruct, and SelectTypeConstruct, emit branches that exit the construct in each block that is still unterminated after the FIR has been generated in it.

The same thing may be needed for SelectRankConstruct, once it's supported.

This eliminates the need for inserting branches in genFIR(Evaluation).

Follow-up to PR #91614.


Full diff: https://github.com/llvm/llvm-project/pull/92455.diff

2 Files Affected:

  • (modified) flang/lib/Lower/Bridge.cpp (+36-22)
  • (modified) flang/test/Lower/branching-directive.f90 (+70-7)
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index afbc1122de868..9ff00c6458aa2 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -1300,6 +1300,25 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     genBranch(targetEval.block);
   }
 
+  /// A construct contains nested evaluations. Some of these evaluations
+  /// may start a new basic block, others will add code to an existing
+  /// block.
+  /// Collect the list of nested evaluations that are last in their block.
+  /// These evaluations may need a branch exiting from their parent construct.
+  void collectFinalEvaluations(
+      Fortran::lower::pft::Evaluation &construct,
+      llvm::SmallVector<Fortran::lower::pft::Evaluation *> &finals) {
+    Fortran::lower::pft::Evaluation *previous = nullptr;
+    Fortran::lower::pft::Evaluation *exit = construct.constructExit;
+    for (auto &nested : construct.getNestedEvaluations()) {
+      if (nested.block != nullptr && previous != nullptr && previous != exit)
+        finals.push_back(previous);
+      previous = &nested;
+    }
+    if (previous != exit)
+      finals.push_back(previous);
+  }
+
   /// Generate a SelectOp or branch sequence that compares \p selector against
   /// values in \p valueList and targets corresponding labels in \p labelList.
   /// If no value matches the selector, branch to \p defaultEval.
@@ -2107,6 +2126,9 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     }
 
     // Unstructured branch sequence.
+    llvm::SmallVector<Fortran::lower::pft::Evaluation *> finals;
+    collectFinalEvaluations(eval, finals);
+
     for (Fortran::lower::pft::Evaluation &e : eval.getNestedEvaluations()) {
       auto genIfBranch = [&](mlir::Value cond) {
         if (e.lexicalSuccessor == e.controlSuccessor) // empty block -> exit
@@ -2127,6 +2149,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
         genIfBranch(genIfCondition(s));
       } else {
         genFIR(e);
+        if (blockIsUnterminated() && llvm::is_contained(finals, &e))
+          genConstructExitBranch(*eval.constructExit);
       }
     }
   }
@@ -2135,11 +2159,17 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     Fortran::lower::pft::Evaluation &eval = getEval();
     Fortran::lower::StatementContext stmtCtx;
     pushActiveConstruct(eval, stmtCtx);
+
+    llvm::SmallVector<Fortran::lower::pft::Evaluation *> finals;
+    collectFinalEvaluations(eval, finals);
+
     for (Fortran::lower::pft::Evaluation &e : eval.getNestedEvaluations()) {
       if (e.getIf<Fortran::parser::EndSelectStmt>())
         maybeStartBlock(e.block);
       else
         genFIR(e);
+      if (blockIsUnterminated() && llvm::is_contained(finals, &e))
+        genConstructExitBranch(*eval.constructExit);
     }
     popActiveConstruct();
   }
@@ -3005,6 +3035,10 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     }
 
     pushActiveConstruct(getEval(), stmtCtx);
+    llvm::SmallVector<Fortran::lower::pft::Evaluation *> finals;
+    collectFinalEvaluations(getEval(), finals);
+    Fortran::lower::pft::Evaluation &constructExit = *getEval().constructExit;
+
     for (Fortran::lower::pft::Evaluation &eval :
          getEval().getNestedEvaluations()) {
       setCurrentPosition(eval.position);
@@ -3201,6 +3235,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
       } else {
         genFIR(eval);
       }
+      if (blockIsUnterminated() && llvm::is_contained(finals, &eval))
+        genConstructExitBranch(constructExit);
     }
     popActiveConstruct();
   }
@@ -4535,28 +4571,6 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     setCurrentEval(eval);
     setCurrentPosition(eval.position);
     eval.visit([&](const auto &stmt) { genFIR(stmt); });
-
-    // Generate an end-of-block branch for several special cases. For
-    // constructs, this can be done for either the end construct statement,
-    // or for the construct itself, which will skip this code if the
-    // end statement was visited first and generated a branch.
-    Fortran::lower::pft::Evaluation *successor = [&]() {
-      if (eval.isConstruct() ||
-          (eval.isDirective() && eval.hasNestedEvaluations()))
-        return eval.getLastNestedEvaluation().lexicalSuccessor;
-      return eval.lexicalSuccessor;
-    }();
-
-    if (successor && blockIsUnterminated()) {
-      if (successor->isIntermediateConstructStmt() &&
-          successor->parentConstruct->lowerAsUnstructured())
-        // Exit from an intermediate unstructured IF or SELECT construct block.
-        genBranch(successor->parentConstruct->constructExit->block);
-      else if (unstructuredContext && eval.isConstructStmt() &&
-               successor == eval.controlSuccessor)
-        // Exit from a degenerate, empty construct block.
-        genBranch(eval.parentConstruct->constructExit->block);
-    }
   }
 
   /// Map mlir function block arguments to the corresponding Fortran dummy
diff --git a/flang/test/Lower/branching-directive.f90 b/flang/test/Lower/branching-directive.f90
index a0a147f1053a4..69270d7bcbe96 100644
--- a/flang/test/Lower/branching-directive.f90
+++ b/flang/test/Lower/branching-directive.f90
@@ -1,25 +1,88 @@
-!RUN: flang-new -fc1 -emit-hlfir -fopenmp -o - %s | FileCheck %s
+!RUN: bbc -emit-hlfir -fopenacc -fopenmp -o - %s | FileCheck %s
 
 !https://github.com/llvm/llvm-project/issues/91526
 
+!CHECK-LABEL: func.func @_QPsimple1
 !CHECK:   cf.cond_br %{{[0-9]+}}, ^bb[[THEN:[0-9]+]], ^bb[[ELSE:[0-9]+]]
 !CHECK: ^bb[[THEN]]:
-!CHECK:   cf.br ^bb[[EXIT:[0-9]+]]
+!CHECK:   omp.parallel
+!CHECK:   cf.br ^bb[[ENDIF:[0-9]+]]
 !CHECK: ^bb[[ELSE]]:
 !CHECK:   fir.call @_FortranAStopStatement
 !CHECK:   fir.unreachable
-!CHECK: ^bb[[EXIT]]:
+!CHECK: ^bb[[ENDIF]]:
+!CHECK:   return
 
-subroutine simple(y)
+subroutine simple1(y)
   implicit none
   logical, intent(in) :: y
   integer :: i
   if (y) then
-!$omp parallel
+    !$omp parallel
     i = 1
-!$omp end parallel
+    !$omp end parallel
   else
     stop 1
   end if
-end subroutine simple
+end subroutine
+
+!CHECK-LABEL: func.func @_QPsimple2
+!CHECK:   cf.cond_br %{{[0-9]+}}, ^bb[[THEN:[0-9]+]], ^bb[[ELSE:[0-9]+]]
+!CHECK: ^bb[[THEN]]:
+!CHECK:   omp.parallel
+!CHECK:   cf.br ^bb[[ENDIF:[0-9]+]]
+!CHECK: ^bb[[ELSE]]:
+!CHECK:   fir.call @_FortranAStopStatement
+!CHECK:   fir.unreachable
+!CHECK: ^bb[[ENDIF]]:
+!CHECK:   fir.call @_FortranAioOutputReal64
+!CHECK:   return
+subroutine simple2(x, yn)
+  implicit none
+  logical, intent(in) :: yn
+  integer, intent(in) :: x
+  integer :: i
+  real(8) :: E
+  E = 0d0
+
+  if (yn) then
+     !$omp parallel do private(i) reduction(+:E)
+     do i = 1, x
+        E = E + i
+     end do
+     !$omp end parallel do
+  else
+     stop 1
+  end if
+  print *, E
+end subroutine
+
+!CHECK-LABEL: func.func @_QPacccase
+!CHECK: fir.select_case %{{[0-9]+}} : i32 [{{.*}}, ^bb[[CASE1:[0-9]+]], {{.*}}, ^bb[[CASE2:[0-9]+]], {{.*}}, ^bb[[CASE3:[0-9]+]]]
+!CHECK: ^bb[[CASE1]]:
+!CHECK:   acc.serial
+!CHECK:   cf.br ^bb[[EXIT:[0-9]+]]
+!CHECK: ^bb[[CASE2]]:
+!CHECK:   fir.call @_FortranAioOutputAscii
+!CHECK:   cf.br ^bb[[EXIT]]
+!CHECK: ^bb[[CASE3]]:
+!CHECK:   fir.call @_FortranAioOutputAscii
+!CHECK:   cf.br ^bb[[EXIT]]
+!CHECK: ^bb[[EXIT]]:
+!CHECK:   return
+subroutine acccase(var)
+  integer :: var
+  integer :: res(10)
+  select case (var)
+    case (1)
+      print *, "case 1"
+      !$acc serial
+      res(1) = 1
+      !$acc end serial
+    case (2)
+      print *, "case 2"
+    case default
+      print *, "case default"
+  end select
+end subroutine
 

@vdonaldson
Copy link
Contributor

vdonaldson commented May 17, 2024

@kparzysz Thanks for working on this issue.

  1. I'm seeing some failures with this change, such as for the test case below.
  2. Are you aware of the isIntermediateConstructStmt Evaluation utility? That might make the extra collectFinalEvaluations traversal unnecessary. That would also provide a way to treat the last block of a sequence different than intermediate blocks, if that contributes to the problem below.
  3. You are deleting not only the intermediate unstructured branch code, but also the degenerate construct block code. Accounting for degenerate code blocks is one complication with this functionality, although I don't immediately see a problem with that in this PR. Transitions between, and various nesting instances of structured and unstructured code sequences are another complication.
  4. New line 3040 looks like it may be dead code.

[Thanks for the new test]

Test - Expected output 1111, not 1101.

subroutine s(j, k)
    goto (11, 22, 33) j-3  ! computed goto - an expression outside [1,3] is a nop
    if (j == 2) goto 22
    if (j == 1) goto 11
    k = k + 1
11  k = k + 10
22  k = k + 100
33  k = k + 1000
end

program p
  integer :: n = 0
  call s(0, n)
  print*, n
end

@Leporacanthicus
Copy link
Contributor

Nice, this solves the issue I was having with OpenMP and if something then <openmp loop here> else stop.

Thanks for the work on this.

@kparzysz
Copy link
Contributor Author

  1. Do you have a test that illustrates the problem? The 91526 test case looks ok after your prior PR.

That PR had two testcases in the comments that still failed, which made me think that the problem has more to do with how we lower constructs, not specifically with OpenMP/OpenACC directives. The goal here was to replace that patch with something that's better grounded in self-evident principles.

Thanks for the pointer about isIntermediateConstructStmt, and the new testcase.

The "final" (block-ending) evaluations are now separwted into two
categories: those that exit the construct, and those that were
"interrupted" by a branch into the construct. The latter ones need
to continue to their lexical successor.
@kparzysz
Copy link
Contributor Author

The previous fix failed for unstructured flow because it assumed that every evaluation that ends a block also exits the construct.
The updated code recognizes evaluations that end a block (within a construct) because their lexical successor is a target of some branch.

However, I think this can be fixed in a more elegant way. Below is the PFT for the most recent problematic subroutine:

1 Subroutine s: subroutine s(j, k)
  1 ComputedGotoStmt! -> 9: goto(11, 22, 33) j-3
  <<IfConstruct!>> -> 10
    2 ^IfStmt [negate] -> 10: if(j == 2) goto 22
    <<IfConstruct>> -> 9
      5 ^IfStmt [negate] -> 9: if(j == 1) goto 11
      8 ^AssignmentStmt: k = k + 1
      7 EndIfStmt
    <<End IfConstruct>>
    9 ^AssignmentStmt!: 11 k = k + 10
    4 EndIfStmt
  <<End IfConstruct!>>
  10 ^AssignmentStmt: 22 k = k + 100
  11 ^AssignmentStmt: 33 k = k + 1000
  12 EndSubroutineStmt: end
End Subroutine s

The outermost IfConstruct has a branch going from outside of it (the computed goto), to the assignment at index 9. If this was a Fortran program, this would not be allowed, and this additional patch would not have been necessary.

I think that a better fix would be to avoid creating such constructs in the PFT, if they didn't exist in the source. What do you think?

@vdonaldson
Copy link
Contributor

I believe the issue here is that individual directives sometimes act like action statements, sometimes like construct statements, and sometimes like constructs. The "simple2" test case can be resolved by treating the OpenMPConstruct directive like a construct, and the OmpEndLoopDirective directive like a construct statement, possibly with the following change.

One alternative to the first diff group below would be to instead add OmpEndLoopDirective to the isExecutableDirective group. Other variants might also be possible.

There may or may not be additional directive cases that should also be managed with similar changes. OpenMP and OpenACC directives are already given special treatment in various places. I think that special treatment should be extended to handle problems such as this one.

diff --git a/flang/lib/Lower/PFTBuilder.cpp b/flang/lib/Lower/PFTBuilder.cpp
index f196b9c5a0cb..e11329010cb9 100644
--- a/flang/lib/Lower/PFTBuilder.cpp
+++ b/flang/lib/Lower/PFTBuilder.cpp
@@ -467,7 +467,9 @@ private:
     evaluationListStack.back()->emplace_back(std::move(eval));
     lower::pft::Evaluation *p = &evaluationListStack.back()->back();
     if (p->isActionStmt() || p->isConstructStmt() || p->isEndStmt() ||
-        p->isExecutableDirective()) {
+        p->isExecutableDirective() ||
+        p->isA<Fortran::parser::OmpEndLoopDirective>()) {
+
       if (lastLexicalEvaluation) {
         lastLexicalEvaluation->lexicalSuccessor = p;
         p->printIndex = lastLexicalEvaluation->printIndex + 1;
@@ -1024,6 +1026,9 @@ private:
           },
           [&](const parser::WhereConstruct &) { setConstructExit(eval); },

+          // Directives - set (unstructured) directive exit targets
+          [&](const parser::OpenMPConstruct &) { setConstructExit(eval); },
+
           // Default - Common analysis for IO statements; otherwise nop.
           [&](const auto &stmt) {
             using A = std::decay_t<decltype(stmt)>;

@kparzysz
Copy link
Contributor Author

The "simple2" test case can be resolved by treating the OpenMPConstruct directive like a construct, and the OmpEndLoopDirective directive like a construct statement, possibly with the following change.

This change alone causes 31 tests to fail in check-flang.

@vdonaldson
Copy link
Contributor

@kparzysz - You are right. Thanks to the evolution of directive processing, I'm not having much success extending the existing processing to handle these cases. I don't find any test failures with your fix. Please go ahead with it. Thanks for working on this!

@kparzysz kparzysz merged commit c1b5b7c into llvm:main May 21, 2024
4 checks passed
@kparzysz kparzysz deleted the users/kparzysz/exit-branch branch May 21, 2024 13:20
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants