Skip to content

Commit

Permalink
Improve flattening without scalarization (#8214)
Browse files Browse the repository at this point in the history
- Flatten equations and algorithms with an empty prefix when not doing
  scalarization, to get rid of any remaining split indices.
  • Loading branch information
perost committed Nov 24, 2021
1 parent 7c3ddb9 commit 8213b2c
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 26 deletions.
35 changes: 22 additions & 13 deletions OMCompiler/Compiler/NFFrontEnd/NFFlatten.mo
Expand Up @@ -136,6 +136,8 @@ uniontype FlattenSettings
end SETTINGS;
end FlattenSettings;

constant ComponentRef EMPTY_PREFIX = ComponentRef.EMPTY();

function flatten
input InstNode classInst;
input String name;
Expand Down Expand Up @@ -164,7 +166,7 @@ algorithm

deleted_vars := UnorderedSet.new(ComponentRef.hash, ComponentRef.isEqual);

(vars, sections) := flattenClass(InstNode.getClass(classInst), ComponentRef.EMPTY(),
(vars, sections) := flattenClass(InstNode.getClass(classInst), EMPTY_PREFIX,
Visibility.PUBLIC, NONE(), {}, sections, deleted_vars, settings);
vars := listReverseInPlace(vars);

Expand Down Expand Up @@ -736,10 +738,10 @@ algorithm
case Sections.SECTIONS()
algorithm
for eqn in listReverse(sects.equations) loop
sections := Sections.prependEquation(vectorizeEquation(eqn, dimensions, prefix), sections);
sections := Sections.prependEquation(vectorizeEquation(eqn, dimensions, prefix, settings), sections);
end for;
for eqn in listReverse(sects.initialEquations) loop
sections := Sections.prependEquation(vectorizeEquation(eqn, dimensions, prefix), sections, true);
sections := Sections.prependEquation(vectorizeEquation(eqn, dimensions, prefix, settings), sections, true);
end for;
for alg in listReverse(sects.algorithms) loop
sections := Sections.prependAlgorithm(vectorizeAlgorithm(alg, dimensions, prefix), sections);
Expand All @@ -752,12 +754,15 @@ algorithm
end vectorizeArray;

function vectorizeEquation
input Equation eqn;
input output Equation eqn;
input list<Dimension> dimensions;
input ComponentRef prefix;
output Equation veqn;
input FlattenSettings settings;
algorithm
veqn := match eqn
// Flatten with an empty prefix to get rid of any split indices.
{eqn} := flattenEquation(eqn, EMPTY_PREFIX, {}, settings);

eqn := match eqn
local
InstNode iter;
list<InstNode> iters;
Expand All @@ -775,31 +780,33 @@ algorithm
algorithm
(iters, ranges, subs) := makeIterators(prefix, dimensions);
subs := listReverseInPlace(subs);
veqn := Equation.mapExp(eqn, function addIterator(prefix = prefix, subscripts = subs));
eqn := Equation.mapExp(eqn, function addIterator(prefix = prefix, subscripts = subs));
src := Equation.source(eqn);

iter :: iters := iters;
range :: ranges := ranges;
veqn := Equation.FOR(iter, SOME(range), {veqn}, src);
eqn := Equation.FOR(iter, SOME(range), {eqn}, src);

while not listEmpty(iters) loop
iter :: iters := iters;
range :: ranges := ranges;
veqn := Equation.FOR(iter, SOME(range), {veqn}, src);
eqn := Equation.FOR(iter, SOME(range), {eqn}, src);
end while;
then
veqn;
eqn;

end match;
end vectorizeEquation;

function vectorizeAlgorithm
input Algorithm alg;
input output Algorithm alg;
input list<Dimension> dimensions;
input ComponentRef prefix;
output Algorithm valg;
algorithm
valg := match alg
// Flatten with an empty prefix to get rid of any split indices.
alg.statements := flattenStatements(alg.statements, EMPTY_PREFIX);

alg := match alg
local
InstNode iter;
list<InstNode> iters;
Expand Down Expand Up @@ -1200,6 +1207,7 @@ algorithm
case Equation.FOR()
algorithm
if settings.arrayConnect then
eq.body := flattenEquations(eq.body, EMPTY_PREFIX, settings);
eql := eq :: equations;
elseif not settings.scalarize then
eql := splitForLoop(eq, prefix, equations, settings);
Expand Down Expand Up @@ -1422,6 +1430,7 @@ protected
Equation eq;
algorithm
Equation.FOR(iter, range, body, src) := forLoop;
body := flattenEquations(body, EMPTY_PREFIX, settings);
(connects, non_connects) := splitForLoop2(body);

if not listEmpty(connects) then
Expand Down
24 changes: 12 additions & 12 deletions testsuite/flattening/modelica/scodeinst/ArrayConnect3.mo
Expand Up @@ -56,24 +56,24 @@ end ArrayConnect3;
// Real[1000, 100] cells.l.f;
// Real[1000, 100] cells.l.e;
// equation
// for $i1 in 1:999 loop
// for $i2 in 2:100 loop
// cells[$i1,$i2].l.e = cells[$i1,$i2 - 1].r.e;
// for $i1 in 2:1000 loop
// for $i2 in 1:99 loop
// cells[$i1,$i2].u.e = cells[$i1 - 1,$i2].d.e;
// end for;
// end for;
// for $i1 in 1:999 loop
// for $i2 in 1:99 loop
// cells[$i1,$i2].r.f + cells[$i1,$i2 + 1].l.f = 0.0;
// cells[$i1,$i2].d.f + cells[$i1 + 1,$i2].u.f = 0.0;
// end for;
// end for;
// for $i1 in 2:1000 loop
// for $i2 in 1:99 loop
// cells[$i1,$i2].u.e = cells[$i1 - 1,$i2].d.e;
// for $i1 in 1:999 loop
// for $i2 in 2:100 loop
// cells[$i1,$i2].l.e = cells[$i1,$i2 - 1].r.e;
// end for;
// end for;
// for $i1 in 1:999 loop
// for $i2 in 1:99 loop
// cells[$i1,$i2].d.f + cells[$i1 + 1,$i2].u.f = 0.0;
// cells[$i1,$i2].r.f + cells[$i1,$i2 + 1].l.f = 0.0;
// end for;
// end for;
// for $i1 in 1:1000 loop
Expand All @@ -83,13 +83,13 @@ end ArrayConnect3;
// cells[$i1,100].r.f + cells[$i1,1].l.f = 0.0;
// end for;
// for $i2 in 1:100 loop
// cells[1,$i2].u.e = S.p.e;
// end for;
// sum(cells[1,:].u.f) + S.p.f = 0.0;
// for $i2 in 1:100 loop
// cells[1000,$i2].d.e = S.n.e;
// end for;
// sum(cells[1000,:].d.f) + S.n.f = 0.0;
// for $i2 in 1:100 loop
// cells[1,$i2].u.e = S.p.e;
// end for;
// sum(cells[1,:].u.f) + S.p.f = 0.0;
// for $i1 in 1:999 loop
// cells[$i1,100].d.f = 0.0;
// end for;
Expand Down
Expand Up @@ -45,8 +45,8 @@ end PrintRecordTypes1;
// Real[100, 10] b.c.rb.y;
// equation
// for i in 1:100 loop
// b[i].c[1].rb = R1(/*Real*/(i), 0.0);
// b[i].c[10] = R2(R1(1.0, 0.0), R1(/*Real*/(i), 1.0));
// b[i].c[1].rb = R1(/*Real*/(i), 0.0);
// end for;
// end PrintRecordTypes1;
// endResult

0 comments on commit 8213b2c

Please sign in to comment.