Skip to content

Commit

Permalink
[NB,NF] update algorithm handling (#10170)
Browse files Browse the repository at this point in the history
* [NB,NF] update algorithm handling

 - move input output detection to a later place where cref types have been flattened
 - use output for algorithm size
 - use input outout for algorithm adjacency matrix

* [NB] update iterator collection adjacency matrix

* [NB] check if emtpy before computing max

* [NB] skip empty for loops when lowering

* [NB] collect dependencies of for loops in algebraic loops
  • Loading branch information
kabdelhak committed Feb 8, 2023
1 parent 8832e77 commit 1998e69
Show file tree
Hide file tree
Showing 12 changed files with 223 additions and 104 deletions.
19 changes: 10 additions & 9 deletions OMCompiler/Compiler/NBackEnd/Classes/NBEquation.mo
Expand Up @@ -126,11 +126,15 @@ public
ComponentRef name;
Expression range;
algorithm
(names, ranges) := List.unzip(frames);
iter := match (names, ranges)
case ({name}, {range}) then SINGLE(name, range);
else NESTED(listArray(names), listArray(ranges));
end match;
if listEmpty(frames) then
iter := EMPTY();
else
(names, ranges) := List.unzip(frames);
iter := match (names, ranges)
case ({name}, {range}) then SINGLE(name, range);
else NESTED(listArray(names), listArray(ranges));
end match;
end if;
end fromFrames;

function getFrames
Expand Down Expand Up @@ -1487,10 +1491,7 @@ public
(names, ranges) := Iterator.getFrames(eqn.iter);
then List.zip(names, ranges);

else algorithm
Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed because eqation is not a for-equation: \n"
+ Equation.toString(eqn)});
then fail();
else {};
end match;
end getForFrames;

Expand Down
14 changes: 12 additions & 2 deletions OMCompiler/Compiler/NBackEnd/Classes/NBStrongComponent.mo
Expand Up @@ -489,6 +489,7 @@ public
algorithm
_ := match comp
local
Pointer<Equation> eqn_ptr;
ComponentRef cref;
list<ComponentRef> dependencies = {}, loop_vars = {}, tmp;
list<tuple<ComponentRef, list<ComponentRef>>> scalarized_dependencies;
Expand Down Expand Up @@ -531,7 +532,7 @@ public
else
cref := comp.var_cref;
end if;
scalarized_dependencies := Slice.getDependentCrefsPseudoForCausalized(cref, dependencies, var_rep, eqn_rep, var_rep_mapping, eqn_rep_mapping, iter, comp.eqn.indices);
scalarized_dependencies := Slice.getDependentCrefsPseudoForCausalized(cref, dependencies, var_rep, eqn_rep, var_rep_mapping, eqn_rep_mapping, iter, comp.eqn.indices, false);
for tpl in listReverse(scalarized_dependencies) loop
(cref, dependencies) := tpl;
updateDependencyMap(cref, dependencies, map, jacType);
Expand Down Expand Up @@ -566,7 +567,7 @@ public
else
cref := comp.var_cref;
end if;
scalarized_dependencies := Slice.getDependentCrefsPseudoForCausalized(cref, dependencies, var_rep, eqn_rep, var_rep_mapping, eqn_rep_mapping, iter, comp.eqn.indices);
scalarized_dependencies := Slice.getDependentCrefsPseudoForCausalized(cref, dependencies, var_rep, eqn_rep, var_rep_mapping, eqn_rep_mapping, iter, comp.eqn.indices, false);
for tpl in listReverse(scalarized_dependencies) loop
(cref, dependencies) := tpl;
updateDependencyMap(cref, dependencies, map, jacType);
Expand All @@ -584,6 +585,15 @@ public
for slice in strict.residual_eqns loop
// ToDo: does this work properly for arrays?
tmp := Equation.collectCrefs(Pointer.access(Slice.getT(slice)), function Slice.getDependentCrefCausalized(set = set));
eqn_ptr := Slice.getT(slice);
if Equation.isForEquation(eqn_ptr) then
// if its a for equation get all dependencies corresponding to their residual.
// we do not really care for order and assume full dependency anyway
eqn as Equation.FOR_EQUATION(iter = iter, body = {body}) := Pointer.access(eqn_ptr);
cref := Equation.getEqnName(eqn_ptr);
scalarized_dependencies := Slice.getDependentCrefsPseudoForCausalized(cref, tmp, var_rep, eqn_rep, var_rep_mapping, eqn_rep_mapping, iter, slice.indices, true);
tmp := List.flatten(list(Util.tuple22(tpl) for tpl in scalarized_dependencies));
end if;
dependencies := listAppend(tmp, dependencies);
end for;

Expand Down
4 changes: 2 additions & 2 deletions OMCompiler/Compiler/NBackEnd/Classes/NBVariable.mo
Expand Up @@ -522,10 +522,10 @@ public
"Returns true, if the variable is a dummy variable.
Note: !Only works in the backend, will return true for any variable if used
during frontend!"
input Variable var;
input Pointer<Variable> var;
output Boolean isDummy;
algorithm
isDummy := match var
isDummy := match Pointer.access(var)
case NFVariable.VARIABLE(backendinfo = BackendExtension.BACKEND_INFO(varKind = BackendExtension.FRONTEND_DUMMY())) then true;
else false;
end match;
Expand Down
77 changes: 39 additions & 38 deletions OMCompiler/Compiler/NBackEnd/Classes/NBackendDAE.mo
Expand Up @@ -739,50 +739,52 @@ protected
EquationAttributes attr;

case FEquation.ARRAY_EQUALITY(lhs = lhs, rhs = rhs, ty = ty, source = source)
guard(Type.isArray(ty))
algorithm
attr := lowerEquationAttributes(ty, init);
//ToDo! How to get Record size and replace NONE()?
guard(Type.isArray(ty)) algorithm
attr := lowerEquationAttributes(ty, init);
//ToDo! How to get Record size and replace NONE()?
then {Pointer.create(BEquation.ARRAY_EQUATION(ty, lhs, rhs, source, attr, NONE()))};

// sometimes regular equalities are array equations aswell. Need to update frontend?
case FEquation.EQUALITY(lhs = lhs, rhs = rhs, ty = ty, source = source)
guard(Type.isArray(ty))
algorithm
attr := lowerEquationAttributes(ty, init);
//ToDo! How to get Record size and replace NONE()?
guard(Type.isArray(ty)) algorithm
attr := lowerEquationAttributes(ty, init);
//ToDo! How to get Record size and replace NONE()?
then {Pointer.create(BEquation.ARRAY_EQUATION(ty, lhs, rhs, source, attr, NONE()))};

case FEquation.EQUALITY(lhs = lhs, rhs = rhs, ty = ty, source = source)
algorithm
attr := lowerEquationAttributes(ty, init);
result := if Type.isComplex(ty) then {Pointer.create(BEquation.RECORD_EQUATION(ty, lhs, rhs, source, attr))}
else {Pointer.create(BEquation.SCALAR_EQUATION(ty, lhs, rhs, source, attr))};
case FEquation.EQUALITY(lhs = lhs, rhs = rhs, ty = ty, source = source) algorithm
attr := lowerEquationAttributes(ty, init);
result := if Type.isComplex(ty) then {Pointer.create(BEquation.RECORD_EQUATION(ty, lhs, rhs, source, attr))}
else {Pointer.create(BEquation.SCALAR_EQUATION(ty, lhs, rhs, source, attr))};
then result;

case FEquation.FOR(range = SOME(range))
algorithm
// Treat each body equation individually because they can have different equation attributes
// E.g.: DISCRETE, EvalStages

iterator := ComponentRef.fromNode(frontend_equation.iterator, Type.INTEGER(), {}, NFComponentRef.Origin.ITERATOR);
for eq in frontend_equation.body loop
new_body := listAppend(lowerEquation(eq, init), new_body);
end for;
for body_elem_ptr in new_body loop
body_elem := Pointer.access(body_elem_ptr);
body_elem := BEquation.FOR_EQUATION(
ty = Type.liftArrayLeftList(Equation.getType(body_elem), {Dimension.fromRange(range)}),
iter = Iterator.SINGLE(iterator, range),
body = {body_elem},
source = frontend_equation.source,
attr = Equation.getAttributes(body_elem)
);

// merge iterators of each for equation instead of having nested loops (for {i in 1:10, j in 1:3, k in 1:5})
Pointer.update(body_elem_ptr, Equation.mergeIterators(body_elem));
result := body_elem_ptr :: result;
end for;
case FEquation.FOR(range = SOME(range)) algorithm
if Expression.rangeSize(range) > 0 then
// Treat each body equation individually because they can have different equation attributes
// E.g.: DISCRETE, EvalStages

iterator := ComponentRef.fromNode(frontend_equation.iterator, Type.INTEGER(), {}, NFComponentRef.Origin.ITERATOR);
for eq in frontend_equation.body loop
new_body := listAppend(lowerEquation(eq, init), new_body);
end for;
for body_elem_ptr in new_body loop
body_elem := Pointer.access(body_elem_ptr);
body_elem := BEquation.FOR_EQUATION(
ty = Type.liftArrayLeftList(Equation.getType(body_elem), {Dimension.fromRange(range)}),
iter = Iterator.SINGLE(iterator, range),
body = {body_elem},
source = frontend_equation.source,
attr = Equation.getAttributes(body_elem)
);

// merge iterators of each for equation instead of having nested loops (for {i in 1:10, j in 1:3, k in 1:5})
Pointer.update(body_elem_ptr, Equation.mergeIterators(body_elem));
result := body_elem_ptr :: result;
end for;
else
if Flags.isSet(Flags.FAILTRACE) then
Error.addMessage(Error.COMPILER_WARNING,{getInstanceName() + ": Empty for-equation got removed:\n" + FEquation.toString(frontend_equation)});
end if;
end if;
then result;

// if equation
Expand Down Expand Up @@ -1078,8 +1080,7 @@ protected
algorithm
// ToDo! check if always DAE.EXPAND() can be used
// ToDo! export inputs
// ToDo! get array sizes instead of only list length
size := listLength(alg.outputs);
size := sum(ComponentRef.size(out) for out in alg.outputs);
attr := if init then NBEquation.EQ_ATTR_DEFAULT_INITIAL
elseif ComponentRef.listHasDiscrete(alg.outputs) then NBEquation.EQ_ATTR_DEFAULT_DISCRETE
else NBEquation.EQ_ATTR_DEFAULT_DYNAMIC;
Expand Down
17 changes: 12 additions & 5 deletions OMCompiler/Compiler/NBackEnd/Modules/1_Main/NBAdjacency.mo
Expand Up @@ -839,8 +839,11 @@ public
Pointer<Equation> derivative;
algorithm
eqn := Pointer.access(eqn_ptr);
// possibly adapt for algorithms
dependencies := BEquation.Equation.collectCrefs(eqn, function Slice.getDependentCref(map = map, pseudo = pseudo));

dependencies := match eqn
case Equation.ALGORITHM() then list(cref for cref guard(UnorderedMap.contains(cref, map)) in listAppend(eqn.alg.inputs, eqn.alg.outputs));
else Equation.collectCrefs(eqn, function Slice.getDependentCref(map = map, pseudo = pseudo));
end match;
dependencies := List.flatten(list(ComponentRef.scalarizeAll(dep) for dep in dependencies));

if (st < MatrixStrictness.FULL) then
Expand Down Expand Up @@ -893,8 +896,8 @@ public
Integer eqn_scal_idx, eqn_size;
list<ComponentRef> unique_dependencies;
algorithm
// ToDo: maybe bottleneck! test this for efficiency
unique_dependencies := List.uniqueOnTrue(list(ComponentRef.simplifySubscripts(dep) for dep in dependencies), ComponentRef.isEqual);
unique_dependencies := list(ComponentRef.simplifySubscripts(dep) for dep in dependencies);
unique_dependencies := UnorderedSet.unique_list(unique_dependencies, ComponentRef.hash, ComponentRef.isEqual);
_ := match (eqn, mapping_opt)
local
Mapping mapping;
Expand All @@ -915,7 +918,11 @@ public
then ();

case (Equation.ALGORITHM(), SOME(mapping)) guard(pseudo) algorithm
fillMatrixArray(unique_dependencies, map, mapping, eqn_arr_idx, m, modes, Slice.getDependentCrefIndicesPseudoArray);
(eqn_scal_idx, eqn_size) := mapping.eqn_AtS[eqn_arr_idx];
row := Slice.getDependentCrefIndicesPseudoScalar(unique_dependencies, map, mapping);
for i in 0:eqn_size-1 loop
arrayUpdate(m, eqn_scal_idx+i, listAppend(row, m[eqn_scal_idx+i]));
end for;
then ();

case (Equation.IF_EQUATION(), SOME(mapping)) guard(pseudo) algorithm
Expand Down

0 comments on commit 1998e69

Please sign in to comment.