Skip to content

Commit ac88b60

Browse files
authored
[NB] update StrongComponent.collectCrefs (#13504)
* [NB] update StrongComponent.collectCrefs - update cref collecting to replace exp iterators
1 parent 66b14b8 commit ac88b60

File tree

4 files changed

+86
-69
lines changed

4 files changed

+86
-69
lines changed

OMCompiler/Compiler/NBackEnd/Classes/NBEquation.mo

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1271,13 +1271,15 @@ public
12711271
to a list of crefs. needs cref filter function."
12721272
input Equation eq;
12731273
input Slice.filterCref filter;
1274+
input MapFuncExpWrapper mapFunc = Expression.map;
12741275
output list<ComponentRef> cref_lst;
12751276
protected
12761277
UnorderedSet<ComponentRef> acc = UnorderedSet.new(ComponentRef.hash, ComponentRef.isEqual);
12771278
algorithm
12781279
// map with the expression and cref filter functions
12791280
_ := map(eq, function Slice.filterExp(filter = filter, acc = acc),
1280-
SOME(function filter(acc = acc)));
1281+
SOME(function filter(acc = acc)),
1282+
mapFunc = mapFunc);
12811283
cref_lst := UnorderedSet.toList(acc);
12821284
end collectCrefs;
12831285

OMCompiler/Compiler/NBackEnd/Classes/NBStrongComponent.mo

Lines changed: 58 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -549,24 +549,20 @@ public
549549

550550
// sliced array equations - create all the single entries
551551
case SINGLE_COMPONENT() guard(Equation.isArrayEquation(comp.eqn)) algorithm
552-
dependencies := Equation.collectCrefs(Pointer.access(comp.eqn), function Slice.getDependentCrefCausalized(set = set));
552+
dependencies := Equation.collectCrefs(Pointer.access(comp.eqn), function Slice.getDependentCrefCausalized(set = set), Expression.mapShallow);
553553
scalarized_dependencies := Slice.getDependentCrefsPseudoArrayCausalized(BVariable.getVarName(comp.var), dependencies);
554-
for tpl in scalarized_dependencies loop
555-
(cref, dependencies) := tpl;
556-
deps_set := prepareDependencies(UnorderedSet.fromList(dependencies, ComponentRef.hash, ComponentRef.isEqual), map, jacType);
557-
updateDependencyMap(cref, deps_set, map);
558-
end for;
554+
addScalarizedDependencies(scalarized_dependencies, map, jacType);
559555
then ();
560556

561557
case SINGLE_COMPONENT() algorithm
562-
dependencies := Equation.collectCrefs(Pointer.access(comp.eqn), function Slice.getDependentCrefCausalized(set = set));
558+
dependencies := Equation.collectCrefs(Pointer.access(comp.eqn), function Slice.getDependentCrefCausalized(set = set), Expression.mapShallow);
563559
dependencies := List.flatten(list(ComponentRef.scalarizeAll(dep) for dep in dependencies));
564560
deps_set := prepareDependencies(UnorderedSet.fromList(dependencies, ComponentRef.hash, ComponentRef.isEqual), map, jacType);
565561
updateDependencyMap(BVariable.getVarName(comp.var), deps_set, map);
566562
then ();
567563

568564
case MULTI_COMPONENT() algorithm
569-
dependencies := Equation.collectCrefs(Pointer.access(Slice.getT(comp.eqn)), function Slice.getDependentCrefCausalized(set = set));
565+
dependencies := Equation.collectCrefs(Pointer.access(Slice.getT(comp.eqn)), function Slice.getDependentCrefCausalized(set = set), Expression.mapShallow);
570566
dependencies := list(ComponentRef.stripIteratorSubscripts(dep) for dep in dependencies);
571567
dependencies := List.flatten(list(ComponentRef.scalarizeAll(dep) for dep in dependencies));
572568
deps_set := prepareDependencies(UnorderedSet.fromList(dependencies, ComponentRef.hash, ComponentRef.isEqual), map, jacType);
@@ -579,88 +575,42 @@ public
579575

580576
// resizable for equations - create all the single entries
581577
case RESIZABLE_COMPONENT() guard(Equation.isForEquation(Slice.getT(comp.eqn))) algorithm
582-
eqn as Equation.FOR_EQUATION(iter = iter, body = {body}) := Pointer.access(Slice.getT(comp.eqn));
583-
dependencies := Equation.collectCrefs(eqn, function Slice.getDependentCrefCausalized(set = set));
584-
if ComponentRef.isEmpty(comp.var_cref) then
585-
Expression.CREF(cref = cref) := Equation.getLHS(body);
586-
else
587-
cref := comp.var_cref;
588-
end if;
589-
scalarized_dependencies := Slice.getDependentCrefsPseudoForCausalized(
590-
cref, dependencies, var_rep, eqn_rep, var_rep_mapping, eqn_rep_mapping,
591-
iter, Equation.size(Slice.getT(comp.eqn)), comp.eqn.indices, false);
592-
for tpl in listReverse(scalarized_dependencies) loop
593-
(cref, dependencies) := tpl;
594-
deps_set := prepareDependencies(UnorderedSet.fromList(dependencies, ComponentRef.hash, ComponentRef.isEqual), map, jacType);
595-
updateDependencyMap(cref, deps_set, map);
596-
end for;
578+
addForLoopDependencies(Pointer.access(Slice.getT(comp.eqn)), comp.eqn.indices, comp.var_cref, var_rep, eqn_rep, var_rep_mapping, eqn_rep_mapping, map, set, jacType);
597579
then ();
598580

599581
// sliced for equations - create all the single entries
600582
case SLICED_COMPONENT() guard(Equation.isForEquation(Slice.getT(comp.eqn))) algorithm
601-
eqn as Equation.FOR_EQUATION(iter = iter, body = {body}) := Pointer.access(Slice.getT(comp.eqn));
602-
dependencies := Equation.collectCrefs(eqn, function Slice.getDependentCrefCausalized(set = set));
603-
if ComponentRef.isEmpty(comp.var_cref) then
604-
Expression.CREF(cref = cref) := Equation.getLHS(body);
605-
else
606-
cref := comp.var_cref;
607-
end if;
608-
scalarized_dependencies := Slice.getDependentCrefsPseudoForCausalized(
609-
cref, dependencies, var_rep, eqn_rep, var_rep_mapping, eqn_rep_mapping,
610-
iter, Equation.size(Slice.getT(comp.eqn)), comp.eqn.indices, false);
611-
for tpl in listReverse(scalarized_dependencies) loop
612-
(cref, dependencies) := tpl;
613-
deps_set := prepareDependencies(UnorderedSet.fromList(dependencies, ComponentRef.hash, ComponentRef.isEqual), map, jacType);
614-
updateDependencyMap(cref, deps_set, map);
615-
end for;
583+
addForLoopDependencies(Pointer.access(Slice.getT(comp.eqn)), comp.eqn.indices, comp.var_cref, var_rep, eqn_rep, var_rep_mapping, eqn_rep_mapping, map, set, jacType);
616584
then ();
617585

618586
// sliced array equations - create all the single entries
619587
case SLICED_COMPONENT() guard(Equation.isArrayEquation(Slice.getT(comp.eqn))) algorithm
620588
eqn := Pointer.access(Slice.getT(comp.eqn));
621-
dependencies := Equation.collectCrefs(eqn, function Slice.getDependentCrefCausalized(set = set));
589+
dependencies := Equation.collectCrefs(eqn, function Slice.getDependentCrefCausalized(set = set), Expression.mapShallow);
622590
scalarized_dependencies := Slice.getDependentCrefsPseudoArrayCausalized(comp.var_cref, dependencies, comp.eqn.indices);
623-
for tpl in scalarized_dependencies loop
624-
(cref, dependencies) := tpl;
625-
deps_set := prepareDependencies(UnorderedSet.fromList(dependencies, ComponentRef.hash, ComponentRef.isEqual), map, jacType);
626-
updateDependencyMap(cref, deps_set, map);
627-
end for;
591+
addScalarizedDependencies(scalarized_dependencies, map, jacType);
628592
then ();
629593

630594
// sliced regular equation.
631595
case SLICED_COMPONENT() algorithm
632596
eqn := Pointer.access(Slice.getT(comp.eqn));
633-
dependencies := Equation.collectCrefs(eqn, function Slice.getDependentCrefCausalized(set = set));
597+
dependencies := Equation.collectCrefs(eqn, function Slice.getDependentCrefCausalized(set = set), Expression.mapShallow);
634598
dependencies := List.flatten(list(ComponentRef.scalarizeAll(dep) for dep in dependencies));
635599
deps_set := prepareDependencies(UnorderedSet.fromList(dependencies, ComponentRef.hash, ComponentRef.isEqual), map, jacType);
636600
updateDependencyMap(comp.var_cref, deps_set, map);
637601
then ();
638602

639603
// sliced for equations - create all the single entries
640604
case GENERIC_COMPONENT() guard(Equation.isForEquation(Slice.getT(comp.eqn))) algorithm
641-
eqn as Equation.FOR_EQUATION(iter = iter, body = {body}) := Pointer.access(Slice.getT(comp.eqn));
642-
dependencies := Equation.collectCrefs(eqn, function Slice.getDependentCrefCausalized(set = set));
643-
if ComponentRef.isEmpty(comp.var_cref) then
644-
Expression.CREF(cref = cref) := Equation.getLHS(body);
645-
else
646-
cref := comp.var_cref;
647-
end if;
648-
scalarized_dependencies := Slice.getDependentCrefsPseudoForCausalized(
649-
cref, dependencies, var_rep, eqn_rep, var_rep_mapping, eqn_rep_mapping,
650-
iter, Equation.size(Slice.getT(comp.eqn)), comp.eqn.indices, false);
651-
for tpl in listReverse(scalarized_dependencies) loop
652-
(cref, dependencies) := tpl;
653-
deps_set := prepareDependencies(UnorderedSet.fromList(dependencies, ComponentRef.hash, ComponentRef.isEqual), map, jacType);
654-
updateDependencyMap(cref, deps_set, map);
655-
end for;
605+
addForLoopDependencies(Pointer.access(Slice.getT(comp.eqn)), comp.eqn.indices, comp.var_cref, var_rep, eqn_rep, var_rep_mapping, eqn_rep_mapping, map, set, jacType);
656606
then ();
657607

658608
case ALGEBRAIC_LOOP(strict = strict) algorithm
659609
// traverse residual equations and collect dependencies
660610
deps_set := UnorderedSet.new(ComponentRef.hash, ComponentRef.isEqual);
661611
for slice in strict.residual_eqns loop
662612
// ToDo: does this work properly for arrays?
663-
tmp := Equation.collectCrefs(Pointer.access(Slice.getT(slice)), function Slice.getDependentCrefCausalized(set = set));
613+
tmp := Equation.collectCrefs(Pointer.access(Slice.getT(slice)), function Slice.getDependentCrefCausalized(set = set), Expression.mapShallow);
664614
eqn_ptr := Slice.getT(slice);
665615
if Equation.isForEquation(eqn_ptr) then
666616
// if its a for equation get all dependencies corresponding to their residual.
@@ -709,6 +659,53 @@ public
709659
end match;
710660
end collectCrefs;
711661

662+
function addScalarizedDependencies
663+
input list<tuple<ComponentRef, list<ComponentRef>>> scalarized_dependencies;
664+
input UnorderedMap<ComponentRef, list<ComponentRef>> map "unordered map to save the dependencies";
665+
input JacobianType jacType "sets the context";
666+
protected
667+
ComponentRef cref;
668+
list<ComponentRef> dependencies;
669+
UnorderedSet<ComponentRef> deps_set;
670+
algorithm
671+
for tpl in listReverse(scalarized_dependencies) loop
672+
(cref, dependencies) := tpl;
673+
deps_set := prepareDependencies(UnorderedSet.fromList(dependencies, ComponentRef.hash, ComponentRef.isEqual), map, jacType);
674+
updateDependencyMap(cref, deps_set, map);
675+
end for;
676+
end addScalarizedDependencies;
677+
678+
function addForLoopDependencies
679+
input Equation eqn;
680+
input list<Integer> indices;
681+
input ComponentRef var_cref;
682+
input VariablePointers var_rep "scalarized variable representatives";
683+
input VariablePointers eqn_rep "scalarized equation representatives";
684+
input Mapping var_rep_mapping "index mapping for variable representatives";
685+
input Mapping eqn_rep_mapping "index mapping for equation representatives";
686+
input UnorderedMap<ComponentRef, list<ComponentRef>> map "unordered map to save the dependencies";
687+
input UnorderedSet<ComponentRef> set "unordered set of array crefs to check for relevance (index lookup)";
688+
input JacobianType jacType "sets the context";
689+
protected
690+
Iterator iter;
691+
Equation body;
692+
list<ComponentRef> dependencies;
693+
ComponentRef cref;
694+
list<tuple<ComponentRef, list<ComponentRef>>> scalarized_dependencies;
695+
algorithm
696+
Equation.FOR_EQUATION(iter = iter, body = {body}) := eqn;
697+
dependencies := Equation.collectCrefs(eqn, function Slice.getDependentCrefCausalized(set = set), Expression.mapShallow);
698+
if ComponentRef.isEmpty(var_cref) then
699+
Expression.CREF(cref = cref) := Equation.getLHS(body);
700+
else
701+
cref := var_cref;
702+
end if;
703+
scalarized_dependencies := Slice.getDependentCrefsPseudoForCausalized(
704+
cref, dependencies, var_rep, eqn_rep, var_rep_mapping, eqn_rep_mapping,
705+
iter, Equation.size(Pointer.create(eqn)), indices, false);
706+
addScalarizedDependencies(scalarized_dependencies, map, jacType);
707+
end addForLoopDependencies;
708+
712709
function addLoopJacobian
713710
input output StrongComponent comp;
714711
input Option<BackendDAE> jac;

OMCompiler/Compiler/NBackEnd/Util/NBAdjacency.mo

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1645,7 +1645,7 @@ public
16451645
occ2 := UnorderedSet.new(ComponentRef.hash, ComponentRef.isEqual);
16461646
filter := function Slice.getDependentCref(map = map, pseudo = true);
16471647
_ := Iterator.map(eqn.iter, function Slice.Slice.filterExp(filter = filter, acc = occ2),
1648-
SOME(function filter(acc = occ2)), Expression.map);
1648+
SOME(function filter(acc = occ2)), Expression.mapShallow);
16491649
// update unsolvables
16501650
Solvability.updateList(UnorderedSet.toList(occ2), Solvability.UNSOLVABLE(), sol_map);
16511651
then UnorderedSet.union(occ1, occ2);

OMCompiler/Compiler/NBackEnd/Util/NBSlice.mo

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ protected
3838
import Slice = NBSlice;
3939

4040
// NF imports
41+
import Call = NFCall;
4142
import ComplexType = NFComplexType;
4243
import ComponentRef = NFComponentRef;
4344
import Dimension = NFDimension;
@@ -232,8 +233,21 @@ public
232233
input UnorderedSet<ComponentRef> acc;
233234
algorithm
234235
() := match exp
236+
local
237+
Expression call_exp;
238+
Call call;
239+
235240
case Expression.CREF() algorithm filter(exp.cref, acc); then ();
236-
else ();
241+
case Expression.CALL(call = call as Call.TYPED_REDUCTION(exp = call_exp)) algorithm
242+
for iter in call.iters loop
243+
call_exp := Expression.replaceIterator(call_exp, Util.tuple21(iter), Util.tuple22(iter));
244+
end for;
245+
_ := Expression.mapShallow(call_exp, function filterExp(filter = filter, acc = acc));
246+
then ();
247+
248+
else algorithm
249+
_ := Expression.mapShallow(exp, function filterExp(filter = filter, acc = acc));
250+
then ();
237251
end match;
238252
end filterExp;
239253

@@ -313,9 +327,9 @@ public
313327
algorithm
314328
// put all unsolvable logic here!
315329
exp := match exp
316-
case Expression.RANGE() then Expression.map(exp, function filterExp(filter = function getDependentCref(map = map, pseudo = pseudo), acc = acc));
317-
case Expression.LBINARY() then Expression.map(exp, function filterExp(filter = function getDependentCref(map = map, pseudo = pseudo), acc = acc));
318-
case Expression.RELATION() then Expression.map(exp, function filterExp(filter = function getDependentCref(map = map, pseudo = pseudo), acc = acc));
330+
case Expression.RANGE() then Expression.mapShallow(exp, function filterExp(filter = function getDependentCref(map = map, pseudo = pseudo), acc = acc));
331+
case Expression.LBINARY() then Expression.mapShallow(exp, function filterExp(filter = function getDependentCref(map = map, pseudo = pseudo), acc = acc));
332+
case Expression.RELATION() then Expression.mapShallow(exp, function filterExp(filter = function getDependentCref(map = map, pseudo = pseudo), acc = acc));
319333
else exp;
320334
end match;
321335
end getUnsolvableExpCrefs;
@@ -640,8 +654,12 @@ public
640654
algorithm
641655
row_cref_scal := ComponentRef.scalarizeSlice(row_cref, slice);
642656
dependencies_scal := list(ComponentRef.scalarizeSlice(dep, slice) for dep in dependencies);
643-
dependencies_scal := List.transposeList(dependencies_scal);
644-
tpl_lst := List.zip(row_cref_scal, dependencies_scal);
657+
if not listEmpty(dependencies_scal) then
658+
dependencies_scal := List.transposeList(dependencies_scal);
659+
tpl_lst := List.zip(row_cref_scal, dependencies_scal);
660+
else
661+
tpl_lst := list((cref, {}) for cref in row_cref_scal);
662+
end if;
645663
end getDependentCrefsPseudoArrayCausalized;
646664

647665
function locationToIndex

0 commit comments

Comments
 (0)