Skip to content

Commit

Permalink
[NF] Expand when conditions.
Browse files Browse the repository at this point in the history
  • Loading branch information
perost authored and OpenModelica-Hudson committed Apr 3, 2018
1 parent e74a420 commit 0616b65
Showing 1 changed file with 113 additions and 23 deletions.
136 changes: 113 additions & 23 deletions Compiler/NFFrontEnd/NFScalarize.mo
Expand Up @@ -51,6 +51,7 @@ import NFPrefixes.Visibility;
import List;
import ElementSource;
import DAE;
import Statement = NFStatement;

public
function scalarize
Expand All @@ -59,22 +60,17 @@ function scalarize
protected
list<Variable> vars = {};
list<Equation> eql = {}, ieql = {};
list<list<Statement>> alg = {}, ialg = {};
algorithm
for c in flatModel.variables loop
vars := scalarizeVariable(c, vars);
end for;

for eq in flatModel.equations loop
eql := scalarizeEquation(eq, eql);
end for;

for eq in flatModel.initialEquations loop
ieql := scalarizeEquation(eq, ieql);
end for;

flatModel.variables := listReverseInPlace(vars);
flatModel.equations := listReverseInPlace(eql);
flatModel.initialEquations := listReverseInPlace(ieql);
flatModel.equations := scalarizeEquations(flatModel.equations);
flatModel.initialEquations := scalarizeEquations(flatModel.initialEquations);
flatModel.algorithms := list(scalarizeAlgorithm(a) for a in flatModel.algorithms);
flatModel.initialAlgorithms := list(scalarizeAlgorithm(a) for a in flatModel.initialAlgorithms);

execStat(getInstanceName() + "(" + name + ")");
end scalarize;
Expand Down Expand Up @@ -220,7 +216,7 @@ algorithm
then scalarizeIfEquation(eq.branches, eq.source, equations);

case Equation.WHEN()
then Equation.WHEN(list(scalarizeBranch(b) for b in eq.branches), eq.source) :: equations;
then scalarizeWhenEquation(eq.branches, eq.source, equations);

else eq :: equations;
end match;
Expand All @@ -233,15 +229,15 @@ function scalarizeIfEquation
protected
list<tuple<Expression, list<Equation>>> bl = {};
Expression cond;
list<Equation> eql;
list<Equation> body;
algorithm
for b in branches loop
(cond, eql) := b;
eql := scalarizeEquations(eql);
(cond, body) := b;
body := scalarizeEquations(body);

// Remove branches with no equations after scalarization.
if not listEmpty(eql) then
bl := (cond, eql) :: bl;
if not listEmpty(body) then
bl := (cond, body) :: bl;
end if;
end for;

Expand All @@ -252,15 +248,109 @@ algorithm
end if;
end scalarizeIfEquation;

function scalarizeBranch
input output tuple<Expression, list<Equation>> branch;
function scalarizeWhenEquation
input list<tuple<Expression, list<Equation>>> branches;
input DAE.ElementSource source;
input output list<Equation> equations;
protected
Expression exp;
list<Equation> eql;
list<tuple<Expression, list<Equation>>> bl = {};
Expression cond;
list<Equation> body;
algorithm
for b in branches loop
(cond, body) := b;
body := scalarizeEquations(body);

if Type.isArray(Expression.typeOf(cond)) then
cond := Expression.expand(cond);
end if;

bl := (cond, body) :: bl;
end for;

equations := Equation.WHEN(listReverseInPlace(bl), source) :: equations;
end scalarizeWhenEquation;

function scalarizeAlgorithm
input list<Statement> stmts;
output list<Statement> statements = {};
algorithm
(exp, eql) := branch;
branch := (exp, scalarizeEquations(eql));
end scalarizeBranch;
for s in stmts loop
statements := scalarizeStatement(s, statements);
end for;

statements := listReverseInPlace(statements);
end scalarizeAlgorithm;

function scalarizeStatement
input Statement stmt;
input output list<Statement> statements;
algorithm
statements := match stmt
case Statement.FOR()
then Statement.FOR(stmt.iterator, scalarizeAlgorithm(stmt.body), stmt.source) :: statements;

case Statement.IF()
then scalarizeIfStatement(stmt.branches, stmt.source, statements);

case Statement.WHEN()
then scalarizeWhenStatement(stmt.branches, stmt.source, statements);

case Statement.WHILE()
then Statement.WHILE(stmt.condition, scalarizeAlgorithm(stmt.body), stmt.source) :: statements;

else stmt :: statements;
end match;
end scalarizeStatement;

function scalarizeIfStatement
input list<tuple<Expression, list<Statement>>> branches;
input DAE.ElementSource source;
input output list<Statement> statements;
protected
list<tuple<Expression, list<Statement>>> bl = {};
Expression cond;
list<Statement> body;
algorithm
for b in branches loop
(cond, body) := b;
body := scalarizeAlgorithm(body);

// Remove branches with no statements after scalarization.
if not listEmpty(body) then
bl := (cond, body) :: bl;
end if;
end for;

// Add the scalarized if statement to the list of statements unless we don't
// have any branches left.
if not listEmpty(bl) then
statements := Statement.IF(listReverseInPlace(bl), source) :: statements;
end if;
end scalarizeIfStatement;

function scalarizeWhenStatement
input list<tuple<Expression, list<Statement>>> branches;
input DAE.ElementSource source;
input output list<Statement> statements;
protected
list<tuple<Expression, list<Statement>>> bl = {};
Expression cond;
list<Statement> body;
algorithm
for b in branches loop
(cond, body) := b;
body := scalarizeAlgorithm(body);

if Type.isArray(Expression.typeOf(cond)) then
cond := Expression.expand(cond);
end if;

bl := (cond, body) :: bl;
end for;

statements := Statement.WHEN(listReverseInPlace(bl), source) :: statements;
end scalarizeWhenStatement;

annotation(__OpenModelica_Interface="frontend");
end NFScalarize;

0 comments on commit 0616b65

Please sign in to comment.