Skip to content

Commit

Permalink
[NF] Clean up Flatten.
Browse files Browse the repository at this point in the history
- Pass the scalarize flag as a bool where it's needed instead of looking
  up the value of the flag over and over.
- Split the array vectorization off to a separate function to make the
  code cleaner and avoid having to check whether to run the code for
  every array element.
- Simplify binding handling in Flatten.flattenClass.
  • Loading branch information
perost committed Jun 2, 2020
1 parent 42e2e6f commit 4060689
Showing 1 changed file with 86 additions and 70 deletions.
156 changes: 86 additions & 70 deletions OMCompiler/Compiler/NFFrontEnd/NFFlatten.mo
Expand Up @@ -134,14 +134,15 @@ protected
list<Algorithm> alg, ialg;
DAE.ElementSource src;
Option<SCode.Comment> cmt;
Boolean scalarize = Flags.isSet(Flags.NF_SCALARIZE);
algorithm
sections := Sections.EMPTY();
src := ElementSource.createElementSource(InstNode.info(classInst));
src := ElementSource.addCommentToSource(src,
SCodeUtil.getElementComment(InstNode.definition(classInst)));

(vars, sections) := flattenClass(InstNode.getClass(classInst), ComponentRef.EMPTY(),
Visibility.PUBLIC, NONE(), {}, sections);
Visibility.PUBLIC, NONE(), {}, sections, scalarize);
vars := listReverseInPlace(vars);

flatModel := match sections
Expand Down Expand Up @@ -182,12 +183,12 @@ function flattenClass
input Option<Binding> binding;
input output list<Variable> vars;
input output Sections sections;
input Boolean scalarize;
protected
array<InstNode> comps;
list<Binding> bindings;
list<Binding> bindings = {};
Binding b;
algorithm
// print(">" + stringAppendList(List.fill(" ", ComponentRef.depth(prefix)-1)) + ComponentRef.toString(prefix) + "\n");
() := match cls
case Class.INSTANCED_CLASS(elements = ClassTree.FLAT_TREE(components = comps))
algorithm
Expand All @@ -196,36 +197,29 @@ algorithm

if Binding.isBound(b) then
b := flattenBinding(b, ComponentRef.rest(prefix));
bindings := getRecordBindings(b, comps);

Error.assertion(listLength(bindings) == arrayLength(comps),
getInstanceName() + " got record binding with wrong number of elements for " +
ComponentRef.toString(prefix),
sourceInfo());

for c in comps loop
(vars, sections) := flattenComponent(c, prefix, visibility, SOME(listHead(bindings)), vars, sections);
bindings := listRest(bindings);
end for;
else
for c in comps loop
(vars, sections) := flattenComponent(c, prefix, visibility, binding, vars, sections);
end for;
bindings := getRecordBindings(b, comps, prefix);
end if;
end if;

if listEmpty(bindings) then
for c in comps loop
(vars, sections) := flattenComponent(c, prefix, visibility, binding, vars, sections, scalarize);
end for;
else
for c in comps loop
(vars, sections) := flattenComponent(c, prefix, visibility, NONE(), vars, sections);
b :: bindings := bindings;
(vars, sections) := flattenComponent(c, prefix, visibility, SOME(b), vars, sections, scalarize);
end for;
end if;

sections := flattenSections(cls.sections, prefix, sections);
sections := flattenSections(cls.sections, prefix, sections, scalarize);
then
();

case Class.TYPED_DERIVED()
algorithm
(vars, sections) :=
flattenClass(InstNode.getClass(cls.baseClass), prefix, visibility, binding, vars, sections);
flattenClass(InstNode.getClass(cls.baseClass), prefix, visibility, binding, vars, sections, scalarize);
then
();

Expand All @@ -238,7 +232,6 @@ algorithm
();

end match;
// print("<" + stringAppendList(List.fill(" ", ComponentRef.depth(prefix)-1)) + ComponentRef.toString(prefix) + "\n");
end flattenClass;

function flattenComponent
Expand All @@ -248,6 +241,7 @@ function flattenComponent
input Option<Binding> outerBinding;
input output list<Variable> vars;
input output Sections sections;
input Boolean scalarize;
protected
InstNode comp_node;
Component c;
Expand All @@ -264,8 +258,6 @@ algorithm
comp_node := InstNode.resolveOuter(component);
c := InstNode.component(comp_node);

// print("->" + stringAppendList(List.fill(" ", ComponentRef.depth(prefix))) + ComponentRef.toString(prefix) + "." + InstNode.name(component) + "\n");

() := match c
case Component.TYPED_COMPONENT(condition = condition, ty = ty)
algorithm
Expand All @@ -279,7 +271,8 @@ algorithm
vis := if InstNode.isProtected(component) then Visibility.PROTECTED else visibility;

if isComplexComponent(ty) then
(vars, sections) := flattenComplexComponent(comp_node, c, cls, ty, vis, outerBinding, prefix, vars, sections);
(vars, sections) := flattenComplexComponent(comp_node, c, cls, ty,
vis, outerBinding, prefix, vars, sections, scalarize);
else
(vars, sections) := flattenSimpleComponent(comp_node, c, vis, outerBinding,
Class.getTypeAttributes(cls), prefix, vars, sections);
Expand All @@ -296,8 +289,6 @@ algorithm
fail();

end match;

// print("<-" + stringAppendList(List.fill(" ", ComponentRef.depth(prefix))) + ComponentRef.toString(prefix) + "." + InstNode.name(component) + "\n");
end flattenComponent;

function isDeletedComponent
Expand Down Expand Up @@ -479,6 +470,7 @@ end isTypeAttributeNamed;
function getRecordBindings
input Binding binding;
input array<InstNode> comps;
input ComponentRef prefix;
output list<Binding> recordBindings = {};
protected
Expression binding_exp;
Expand All @@ -505,6 +497,11 @@ algorithm
then
fail();
end match;

Error.assertion(listLength(recordBindings) == arrayLength(comps),
getInstanceName() + " got record binding with wrong number of elements for " +
ComponentRef.toString(prefix),
sourceInfo());
end getRecordBindings;

function flattenComplexComponent
Expand All @@ -517,6 +514,7 @@ function flattenComplexComponent
input ComponentRef prefix;
input output list<Variable> vars;
input output Sections sections;
input Boolean scalarize;
protected
list<Dimension> dims;
ComponentRef name;
Expand Down Expand Up @@ -572,9 +570,11 @@ algorithm

// Flatten the class directly if the component is a scalar, otherwise scalarize it.
if listEmpty(dims) then
(vars, sections) := flattenClass(cls, name, visibility, opt_binding, vars, sections);
else
(vars, sections) := flattenClass(cls, name, visibility, opt_binding, vars, sections, scalarize);
elseif scalarize then
(vars, sections) := flattenArray(cls, dims, name, visibility, opt_binding, vars, sections);
else
(vars, sections) := vectorizeArray(cls, dims, name, visibility, opt_binding, vars, sections);
end if;
end flattenComplexComponent;

Expand All @@ -594,44 +594,13 @@ protected
RangeIterator range_iter;
Expression sub_exp;
list<Subscript> subs;
list<Variable> vrs;
Sections sects;
algorithm
// if we don't scalarize flatten the class and vectorize it
if not Flags.isSet(Flags.NF_SCALARIZE) then
(vrs, sects) := flattenClass(cls, prefix, visibility, binding, {}, Sections.SECTIONS({}, {}, {}, {}));
// add dimensions to the types
for v in vrs loop
v.ty := Type.liftArrayLeftList(v.ty, dimensions);
vars := v::vars;
end for;
// vectorize equations
() := match sects
case Sections.SECTIONS()
algorithm
for eqn in listReverse(sects.equations) loop
sections := Sections.prependEquation(vectorizeEquation(eqn, dimensions, prefix), sections);
end for;
for eqn in listReverse(sects.initialEquations) loop
sections := Sections.prependEquation(vectorizeEquation(eqn, dimensions, prefix), sections, true);
end for;
for alg in listReverse(sects.algorithms) loop
sections := Sections.prependAlgorithm(vectorizeAlgorithm(alg, dimensions, prefix), sections);
end for;
for alg in listReverse(sects.initialAlgorithms) loop
sections := Sections.prependAlgorithm(vectorizeAlgorithm(alg, dimensions, prefix), sections, true);
end for;
then ();
end match;
return;
end if;

if listEmpty(dimensions) then
subs := listReverse(subscripts);
sub_pre := ComponentRef.setSubscripts(subs, prefix);

(vars, sections) := flattenClass(cls, sub_pre, visibility,
subscriptBindingOpt(subs, binding), vars, sections);
subscriptBindingOpt(subs, binding), vars, sections, true);
else
dim :: rest_dims := dimensions;
range_iter := RangeIterator.fromDim(dim);
Expand All @@ -644,6 +613,48 @@ algorithm
end if;
end flattenArray;

function vectorizeArray
input Class cls;
input list<Dimension> dimensions;
input ComponentRef prefix;
input Visibility visibility;
input Option<Binding> binding;
input output list<Variable> vars;
input output Sections sections;
input list<Subscript> subscripts = {};
protected
list<Variable> vrs;
Sections sects;
algorithm
// if we don't scalarize flatten the class and vectorize it
(vrs, sects) := flattenClass(cls, prefix, visibility, binding, {}, Sections.SECTIONS({}, {}, {}, {}), false);

// add dimensions to the types
for v in vrs loop
v.ty := Type.liftArrayLeftList(v.ty, dimensions);
vars := v :: vars;
end for;

// vectorize equations
() := match sects
case Sections.SECTIONS()
algorithm
for eqn in listReverse(sects.equations) loop
sections := Sections.prependEquation(vectorizeEquation(eqn, dimensions, prefix), sections);
end for;
for eqn in listReverse(sects.initialEquations) loop
sections := Sections.prependEquation(vectorizeEquation(eqn, dimensions, prefix), sections, true);
end for;
for alg in listReverse(sects.algorithms) loop
sections := Sections.prependAlgorithm(vectorizeAlgorithm(alg, dimensions, prefix), sections);
end for;
for alg in listReverse(sects.initialAlgorithms) loop
sections := Sections.prependAlgorithm(vectorizeAlgorithm(alg, dimensions, prefix), sections, true);
end for;
then ();
end match;
end vectorizeArray;

function vectorizeEquation
input Equation eqn;
input list<Dimension> dimensions;
Expand Down Expand Up @@ -935,6 +946,7 @@ function flattenSections
input Sections sections;
input ComponentRef prefix;
input output Sections accumSections;
input Boolean scalarize;
algorithm
() := match sections
local
Expand All @@ -943,8 +955,8 @@ algorithm

case Sections.SECTIONS()
algorithm
eq := flattenEquations(sections.equations, prefix);
ieq := flattenEquations(sections.initialEquations, prefix);
eq := flattenEquations(sections.equations, prefix, scalarize);
ieq := flattenEquations(sections.initialEquations, prefix, scalarize);
alg := flattenAlgorithms(sections.algorithms, prefix);
ialg := flattenAlgorithms(sections.initialAlgorithms, prefix);
accumSections := Sections.prepend(eq, ieq, alg, ialg, accumSections);
Expand All @@ -958,17 +970,19 @@ end flattenSections;
function flattenEquations
input list<Equation> eql;
input ComponentRef prefix;
input Boolean scalarize;
output list<Equation> equations = {};
algorithm
for eq in eql loop
equations := flattenEquation(eq, prefix, equations);
equations := flattenEquation(eq, prefix, equations, scalarize);
end for;
end flattenEquations;

function flattenEquation
input Equation eq;
input ComponentRef prefix;
input output list<Equation> equations;
input Boolean scalarize;
algorithm
equations := match eq
local
Expand All @@ -984,7 +998,7 @@ algorithm

case Equation.FOR()
algorithm
if Flags.isSet(Flags.NF_SCALARIZE) then
if scalarize then
eql := unrollForLoop(eq, prefix, equations);
else
eql := splitForLoop(eq, prefix, equations);
Expand All @@ -999,11 +1013,11 @@ algorithm
Equation.CONNECT(e1, e2, eq.source) :: equations;

case Equation.IF()
then flattenIfEquation(eq, prefix, equations);
then flattenIfEquation(eq, prefix, equations, scalarize);

case Equation.WHEN()
algorithm
eq.branches := list(flattenEqBranch(b, prefix) for b in eq.branches);
eq.branches := list(flattenEqBranch(b, prefix, scalarize) for b in eq.branches);
then
eq :: equations;

Expand Down Expand Up @@ -1042,6 +1056,7 @@ function flattenIfEquation
input Equation eq;
input ComponentRef prefix;
input output list<Equation> equations;
input Boolean scalarize;
protected
Equation.Branch branch;
list<Equation.Branch> branches, bl = {};
Expand Down Expand Up @@ -1070,7 +1085,7 @@ algorithm
algorithm
// Flatten the condition and body of the branch.
cond := flattenExp(cond, prefix);
eql := flattenEquations(eql, prefix);
eql := flattenEquations(eql, prefix, scalarize);

// Evaluate structural conditions.
if var <= Variability.STRUCTURAL_PARAMETER then
Expand Down Expand Up @@ -1149,14 +1164,15 @@ end isConnectEq;
function flattenEqBranch
input output Equation.Branch branch;
input ComponentRef prefix;
input Boolean scalarize;
protected
Expression exp;
list<Equation> eql;
Variability var;
algorithm
Equation.Branch.BRANCH(exp, var, eql) := branch;
exp := flattenExp(exp, prefix);
eql := flattenEquations(eql, prefix);
eql := flattenEquations(eql, prefix, scalarize);
branch := Equation.makeBranch(exp, listReverseInPlace(eql), var);
end flattenEqBranch;

Expand All @@ -1182,7 +1198,7 @@ algorithm
(range_iter, val) := RangeIterator.next(range_iter);
unrolled_body := Equation.mapExpList(body,
function Expression.replaceIterator(iterator = iter, iteratorValue = val));
unrolled_body := flattenEquations(unrolled_body, prefix);
unrolled_body := flattenEquations(unrolled_body, prefix, scalarize = true);
equations := listAppend(unrolled_body, equations);
end while;
end unrollForLoop;
Expand Down

0 comments on commit 4060689

Please sign in to comment.