Skip to content

Commit

Permalink
ParModelica fixes for the new frontend (#8260)
Browse files Browse the repository at this point in the history
- Handle parfor statements.
- Add dumping of parfor statements in DAEDumpTpl.
- Declare builtin ParModelica functions as impure so the NF doesn't try
  to evaluate them.
  • Loading branch information
perost committed Dec 3, 2021
1 parent dbc8d8b commit d9142e1
Show file tree
Hide file tree
Showing 9 changed files with 237 additions and 27 deletions.
31 changes: 28 additions & 3 deletions OMCompiler/Compiler/NFFrontEnd/NFConvertDAE.mo
Expand Up @@ -921,15 +921,40 @@ protected
list<Statement> body;
list<DAE.Statement> dbody;
DAE.ElementSource source;
Statement.ForType for_type;
list<tuple<DAE.ComponentRef, SourceInfo>> loop_vars;
algorithm
Statement.FOR(iterator = iterator, range = SOME(range), body = body, source = source) := forStmt;
Statement.FOR(iterator = iterator, range = SOME(range), body = body, forType = for_type, source = source) := forStmt;
dbody := convertStatements(body);
Component.ITERATOR(ty = ty) := InstNode.component(iterator);

forDAE := DAE.Statement.STMT_FOR(Type.toDAE(ty), Type.isArray(ty),
InstNode.name(iterator), 0, Expression.toDAE(range), dbody, source);
forDAE := match for_type
case Statement.ForType.NORMAL()
then DAE.Statement.STMT_FOR(Type.toDAE(ty), Type.isArray(ty),
InstNode.name(iterator), 0, Expression.toDAE(range), dbody, source);

case Statement.ForType.PARALLEL()
algorithm
loop_vars := list(convertForStatementParallelVar(v) for v in for_type.vars);
then
DAE.Statement.STMT_PARFOR(Type.toDAE(ty), Type.isArray(ty),
InstNode.name(iterator), 0, Expression.toDAE(range), dbody, loop_vars, source);
end match;
end convertForStatement;

function convertForStatementParallelVar
input tuple<ComponentRef, SourceInfo> var;
output tuple<DAE.ComponentRef, SourceInfo> outVar;
protected
ComponentRef cref;
DAE.ComponentRef dcref;
SourceInfo info;
algorithm
(cref, info) := var;
dcref := ComponentRef.toDAE(cref);
outVar := (dcref, info);
end convertForStatementParallelVar;

function convertIfStatement
input list<tuple<Expression, list<Statement>>> ifBranches;
input DAE.ElementSource source;
Expand Down
95 changes: 92 additions & 3 deletions OMCompiler/Compiler/NFFrontEnd/NFFlatten.mo
Expand Up @@ -77,7 +77,7 @@ import Face = NFConnector.Face;
import System;
import ComplexType = NFComplexType;
import NFInstNode.CachedData;
import NFPrefixes.{Direction, Variability, Visibility};
import NFPrefixes.{Direction, Variability, Visibility, Parallelism};
import Variable = NFVariable;
import ElementSource;
import Ceval = NFCeval;
Expand Down Expand Up @@ -829,12 +829,12 @@ algorithm

iter :: iters := iters;
range :: ranges := ranges;
stmt := Statement.FOR(iter, SOME(range), body, alg.source);
stmt := Statement.FOR(iter, SOME(range), body, Statement.ForType.NORMAL(), alg.source);

while not listEmpty(iters) loop
iter :: iters := iters;
range :: ranges := ranges;
stmt := Statement.FOR(iter, SOME(range), body, alg.source);
stmt := Statement.FOR(iter, SOME(range), body, Statement.ForType.NORMAL(), alg.source);
end while;
then
Algorithm.ALGORITHM({stmt}, alg.source);
Expand Down Expand Up @@ -1534,6 +1534,7 @@ algorithm
algorithm
stmt.range := Util.applyOption(stmt.range, function flattenExp(prefix = prefix));
stmt.body := flattenStatements(stmt.body, prefix);
stmt.forType := updateForType(stmt.forType, stmt.body);
then
stmt;

Expand Down Expand Up @@ -2184,5 +2185,93 @@ algorithm
end match;
end collectClassFunctions;

function updateForType
input output Statement.ForType forType;
input list<Statement> forBody;
protected
UnorderedMap<ComponentRef, SourceInfo> vars;
algorithm
() := match forType
case Statement.ForType.NORMAL() then ();

case Statement.ForType.PARALLEL()
algorithm
// ParModelica needs to know which variables are used in the loop body,
// so collect them here and add them to the ForType.
vars := UnorderedMap.new<SourceInfo>(ComponentRef.hash, ComponentRef.isEqual);

for s in forBody loop
vars := Statement.fold(s, collectParallelVariables, vars);
end for;

forType.vars := UnorderedMap.toList(vars);

// Only parglobal variables are allowed to be used in a parfor loop.
for v in forType.vars loop
checkParGlobalCref(v);
end for;
then
();

end match;
end updateForType;

function collectParallelVariables
input Statement stmt;
input output UnorderedMap<ComponentRef, SourceInfo> vars;
protected
SourceInfo info;
algorithm
info := Statement.info(stmt);
vars := Statement.foldExp(stmt,
function Expression.fold(func = function collectParallelVariablesExp(info = info)), vars);
end collectParallelVariables;

function collectParallelVariablesExp
input Expression exp;
input SourceInfo info;
input output UnorderedMap<ComponentRef, SourceInfo> vars;
protected
InstNode node;
ComponentRef cref;
algorithm
() := match exp
case Expression.CREF()
guard ComponentRef.isCref(exp.cref) and
not ComponentRef.isIterator(exp.cref) and
InstNode.isComponent(ComponentRef.node(exp.cref))
algorithm
cref := ComponentRef.stripSubscriptsAll(exp.cref);
UnorderedMap.tryAdd(cref, info, vars);
then
();

else ();
end match;
end collectParallelVariablesExp;

function checkParGlobalCref
input tuple<ComponentRef, SourceInfo> crefInfo;
protected
ComponentRef cref;
SourceInfo info;
InstNode node;
String errorString;
algorithm
(cref, info) := crefInfo;
node := ComponentRef.node(cref);

if Component.parallelism(InstNode.component(node)) <> Parallelism.GLOBAL then
errorString := "\n" +
"- Component '" + AbsynUtil.pathString(ComponentRef.toPath(cref)) +
"' is used in a parallel for loop." + "\n" +
"- Parallel for loops can only contain references to parglobal variables"
;
Error.addSourceMessage(Error.PARMODELICA_ERROR,
{errorString}, info);
fail();
end if;
end checkParGlobalCref;

annotation(__OpenModelica_Interface="frontend");
end NFFlatten;
11 changes: 10 additions & 1 deletion OMCompiler/Compiler/NFFrontEnd/NFInst.mo
Expand Up @@ -3243,7 +3243,16 @@ algorithm
next_origin := InstContext.set(context, NFInstContext.FOR);
stmtl := instStatements(scodeStmt.forBody, for_scope, next_origin);
then
Statement.FOR(iter, oexp, stmtl, makeSource(scodeStmt.comment, info));
Statement.FOR(iter, oexp, stmtl, Statement.ForType.NORMAL(), makeSource(scodeStmt.comment, info));

case SCode.Statement.ALG_PARFOR(info = info)
algorithm
oexp := instExpOpt(scodeStmt.range, scope, context, info);
(for_scope, iter) := addIteratorToScope(scodeStmt.index, scope, info);
next_origin := InstContext.set(context, NFInstContext.FOR);
stmtl := instStatements(scodeStmt.parforBody, for_scope, next_origin);
then
Statement.FOR(iter, oexp, stmtl, Statement.ForType.PARALLEL({}), makeSource(scodeStmt.comment, info));

case SCode.Statement.ALG_IF(info = info)
algorithm
Expand Down
2 changes: 1 addition & 1 deletion OMCompiler/Compiler/NFFrontEnd/NFScalarize.mo
Expand Up @@ -360,7 +360,7 @@ function scalarizeStatement
algorithm
statements := match stmt
case Statement.FOR()
then Statement.FOR(stmt.iterator, stmt.range, scalarizeStatements(stmt.body), stmt.source) :: statements;
then Statement.FOR(stmt.iterator, stmt.range, scalarizeStatements(stmt.body), stmt.forType, stmt.source) :: statements;

case Statement.IF()
then scalarizeIfStatement(stmt.branches, stmt.source, statements);
Expand Down
64 changes: 64 additions & 0 deletions OMCompiler/Compiler/NFFrontEnd/NFStatement.mo
Expand Up @@ -34,6 +34,7 @@ encapsulated uniontype NFStatement
import Expression = NFExpression;
import NFInstNode.InstNode;
import DAE;
import ComponentRef = NFComponentRef;

protected
import Statement = NFStatement;
Expand All @@ -42,6 +43,15 @@ protected
import Util;
import IOStream;

public
uniontype ForType
record NORMAL end NORMAL;

record PARALLEL
list<tuple<ComponentRef, SourceInfo>> vars;
end PARALLEL;
end ForType;

public
record ASSIGNMENT
Expression lhs "The asignee";
Expand All @@ -60,6 +70,7 @@ public
InstNode iterator;
Option<Expression> range;
list<Statement> body "The body of the for loop.";
ForType forType;
DAE.ElementSource source;
end FOR;

Expand Down Expand Up @@ -258,6 +269,59 @@ public
stmt := func(stmt);
end map;

function fold<ArgT>
input Statement stmt;
input MapFn func;
input output ArgT arg;

partial function MapFn
input Statement stmt;
input output ArgT arg;
end MapFn;
algorithm
() := match stmt
case FOR()
algorithm
for s in stmt.body loop
arg := fold(s, func, arg);
end for;
then
();

case IF()
algorithm
for b in stmt.branches loop
for s in Util.tuple22(b) loop
arg := fold(s, func, arg);
end for;
end for;
then
();

case WHEN()
algorithm
for b in stmt.branches loop
for s in Util.tuple22(b) loop
arg := fold(s, func, arg);
end for;
end for;
then
();

case WHILE()
algorithm
for s in stmt.body loop
arg := fold(s, func, arg);
end for;
then
();

else ();
end match;

arg := func(stmt, arg);
end fold;

function applyExpList
input list<Statement> stmt;
input FoldFunc func;
Expand Down
2 changes: 1 addition & 1 deletion OMCompiler/Compiler/NFFrontEnd/NFTyping.mo
Expand Up @@ -3107,7 +3107,7 @@ algorithm
next_context := InstContext.set(context, NFInstContext.FOR);
body := typeStatements(st.body, next_context);
then
Statement.FOR(st.iterator, SOME(e1), body, st.source);
Statement.FOR(st.iterator, SOME(e1), body, st.forType, st.source);

case Statement.IF()
algorithm
Expand Down
8 changes: 8 additions & 0 deletions OMCompiler/Compiler/Template/DAEDumpTV.mo
Expand Up @@ -527,6 +527,14 @@ package DAE
ElementSource source;
end STMT_FOR;

record STMT_PARFOR
Boolean iterIsArray;
Ident iter;
Exp range;
list<Statement> statementLst;
ElementSource source;
end STMT_PARFOR;

record STMT_WHILE
Exp exp;
list<Statement> statementLst;
Expand Down
15 changes: 15 additions & 0 deletions OMCompiler/Compiler/Template/DAEDumpTpl.tpl
Expand Up @@ -866,6 +866,7 @@ match stmt
case STMT_ASSIGN_ARR(__) then dumpArrayAssignStatement(stmt)
case STMT_IF(__) then dumpIfStatement(stmt)
case STMT_FOR(__) then dumpForStatement(stmt)
case STMT_PARFOR(__) then dumpParForStatement(stmt)
case STMT_WHILE(__) then dumpWhileStatement(stmt)
case STMT_WHEN(__) then dumpWhenStatement(stmt)
case STMT_ASSERT(__) then dumpAssert(cond, msg, level, source)
Expand Down Expand Up @@ -964,6 +965,20 @@ match stmt
>>
end dumpForStatement;

template dumpParForStatement(DAE.Statement stmt)
::=
match stmt
case STMT_PARFOR(__) then
let range_str = dumpExp(range)
let alg_str = (statementLst |> e => dumpStatement(e) ;separator="\n")
let src_str = dumpSource(source)
<<
parfor <%iter%> in <%range_str%> loop
<%alg_str%>
end for<%src_str%>;
>>
end dumpParForStatement;

template dumpWhileStatement(DAE.Statement stmt)
::=
match stmt
Expand Down

0 comments on commit d9142e1

Please sign in to comment.