Skip to content

Commit 845c8c9

Browse files
authored
[NB] refine FOR and IF nested simplification (#13666)
- add array/list as for-range support - refine error reporting and comments - [testsuite] add test
1 parent c50b4b4 commit 845c8c9

File tree

2 files changed

+228
-43
lines changed

2 files changed

+228
-43
lines changed

OMCompiler/Compiler/NBackEnd/Classes/NBEquation.mo

Lines changed: 84 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -682,21 +682,25 @@ public
682682
// get the only occuring iterator cref and solve the body for it
683683
cref := List.first(occs);
684684
(tmpEqn, _, status, invert) := Solve.solveBody(tmpEqn, cref, FunctionTree.EMPTY());
685+
operator := if invert == NBSolve.RelationInversion.TRUE then Operator.invert(condition.operator) else condition.operator;
685686

686687
// if its solvable, get the corresponding iterator range and adapt it with the information of the if-condition
687688
if status == NBSolve.Status.EXPLICIT and invert <> NBSolve.RelationInversion.UNKNOWN then
688689
range := UnorderedMap.getSafe(cref, iter_map, sourceInfo());
689-
(range, status) := match range
690-
case Expression.RANGE() algorithm
691-
operator := if invert == NBSolve.RelationInversion.TRUE then Operator.invert(condition.operator) else condition.operator;
692-
then (adaptRange(UnorderedMap.getSafe(cref, iter_map, sourceInfo()), Equation.getRHS(tmpEqn), operator.op), status);
690+
try
691+
(range, status) := match range
692+
case Expression.RANGE() then (adaptRange(UnorderedMap.getSafe(cref, iter_map, sourceInfo()), Equation.getRHS(tmpEqn), operator), status);
693693

694-
// ToDo: intercepting this
695-
case Expression.ARRAY() then (range, status);
694+
// ToDo: intercepting this
695+
case Expression.ARRAY() then (adaptArray(UnorderedMap.getSafe(cref, iter_map, sourceInfo()), Equation.getRHS(tmpEqn), operator), status);
696696

697-
// can't do anything here
698-
else (range, NBSolve.Status.UNSOLVABLE);
699-
end match;
697+
// can't do anything here
698+
else (range, NBSolve.Status.UNSOLVABLE);
699+
end match;
700+
else
701+
Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed to combine iterator: " + toString(iter) + " with condition " + Expression.toString(condition) + "."});
702+
fail();
703+
end try;
700704

701705
UnorderedMap.add(cref, range, iter_map);
702706
else
@@ -717,47 +721,47 @@ public
717721
function adaptRange
718722
input output Expression range;
719723
input Expression rhs;
720-
input Operator.Op op;
724+
input Operator operator;
721725
protected
722726
Integer thresh, start, step, stop;
723-
Type ty;
724727
Boolean within_range;
725728
algorithm
726-
(thresh, start, step, stop, ty) := match (rhs, range)
727-
case (Expression.INTEGER(thresh), range as Expression.RANGE(start = Expression.INTEGER(start), step = SOME(Expression.INTEGER(step)), stop = Expression.INTEGER(stop))) then (thresh, start, step, stop, range.ty);
728-
case (Expression.INTEGER(thresh), range as Expression.RANGE(start = Expression.INTEGER(start), stop = Expression.INTEGER(stop))) then (thresh, start, 1, stop, range.ty);
729-
else (0, 0, 0, 0, Type.UNKNOWN());
729+
// extract the primitive type representation
730+
(thresh, start, step, stop) := match (rhs, range)
731+
case (Expression.INTEGER(thresh), range as Expression.RANGE(start = Expression.INTEGER(start), step = SOME(Expression.INTEGER(step)), stop = Expression.INTEGER(stop))) then (thresh, start, step, stop);
732+
case (Expression.INTEGER(thresh), range as Expression.RANGE(start = Expression.INTEGER(start), stop = Expression.INTEGER(stop))) then (thresh, start, 1, stop);
733+
else algorithm
734+
Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed because range could not be evaluated: " + Expression.toString(range)});
735+
then fail();
730736
end match;
731737

732-
if not Type.isUnknown(ty) then
733-
within_range := thresh * sign(step) > start * sign(step) and thresh * sign(step) < stop * sign(step);
734-
735-
range := match op
736-
// i == VAL as a condition
737-
case NFOperator.Op.EQUAL then
738-
// remove all but this element from the range
739-
if within_range then Expression.makeRange(Expression.INTEGER(thresh), NONE(), Expression.INTEGER(thresh))
740-
// this element is not in the range >>> no valid element
741-
else Expression.makeRange(Expression.INTEGER(0), SOME(Expression.INTEGER(0)), Expression.INTEGER(0));
742-
743-
// i <> VAL as a condition
744-
case NFOperator.Op.NEQUAL then
745-
// remove only this element from the range
746-
if within_range then Expression.makeExpArray(listArray(list(Expression.INTEGER(i) for i guard(i <> thresh) in List.intRange3(start, step, stop))), Type.INTEGER(), true)
747-
// this element is not in the range >>> original range not changed
748-
else range;
749-
750-
// i <, <=, >, >= VAL as a condition
751-
case NFOperator.Op.LESS then interceptRange(thresh - 1, start, step, stop, within_range, sign(step) > 0, range, intLe);
752-
case NFOperator.Op.LESSEQ then interceptRange(thresh, start, step, stop, within_range, sign(step) > 0, range, intLt);
753-
case NFOperator.Op.GREATER then interceptRange(thresh + 1, start, step, stop, within_range, sign(step) < 0, range, intGe);
754-
case NFOperator.Op.GREATEREQ then interceptRange(thresh, start, step, stop, within_range, sign(step) < 0, range, intGt);
738+
within_range := thresh * sign(step) > start * sign(step) and thresh * sign(step) < stop * sign(step);
739+
740+
range := match operator.op
741+
// i == VAL as a condition
742+
case NFOperator.Op.EQUAL then
743+
// remove all but this element from the range
744+
if within_range then Expression.makeRange(Expression.INTEGER(thresh), NONE(), Expression.INTEGER(thresh))
745+
// this element is not in the range >>> no valid element
746+
else Expression.makeRange(Expression.INTEGER(0), SOME(Expression.INTEGER(0)), Expression.INTEGER(0));
747+
748+
// i <> VAL as a condition
749+
case NFOperator.Op.NEQUAL then
750+
// remove only this element from the range
751+
if within_range then Expression.makeExpArray(listArray(list(Expression.INTEGER(i) for i guard(i <> thresh) in List.intRange3(start, step, stop))), Type.INTEGER(), true)
752+
// this element is not in the range >>> original range not changed
753+
else range;
754+
755+
// i <, <=, >, >= VAL as a condition
756+
case NFOperator.Op.LESS then interceptRange(thresh - 1, start, step, stop, within_range, sign(step) > 0, range, intLe);
757+
case NFOperator.Op.LESSEQ then interceptRange(thresh, start, step, stop, within_range, sign(step) > 0, range, intLt);
758+
case NFOperator.Op.GREATER then interceptRange(thresh + 1, start, step, stop, within_range, sign(step) < 0, range, intGe);
759+
case NFOperator.Op.GREATEREQ then interceptRange(thresh, start, step, stop, within_range, sign(step) < 0, range, intGt);
755760

756-
else algorithm
757-
Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed for an unknown reason."});
758-
then fail();
759-
end match;
760-
end if;
761+
else algorithm
762+
Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed for operator: " + Operator.toDebugString(operator)});
763+
then fail();
764+
end match;
761765
end adaptRange;
762766

763767
function interceptRange
@@ -791,6 +795,43 @@ public
791795
end if;
792796
end interceptRange;
793797

798+
function adaptArray
799+
input output Expression array;
800+
input Expression rhs;
801+
input Operator operator;
802+
protected
803+
Integer thresh;
804+
list<Integer> elems;
805+
algorithm
806+
// extract the primitive type representation
807+
(thresh, elems) := match (rhs, array)
808+
case (Expression.INTEGER(thresh), Expression.ARRAY(literal = true)) then (thresh, list(Expression.integerValue(e) for e in array.elements));
809+
else algorithm
810+
Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed because array range is non literal: " + Expression.toString(array)});
811+
then fail();
812+
end match;
813+
814+
array := match operator.op
815+
// i == VAL as a condition
816+
case NFOperator.Op.EQUAL then
817+
// remove all but this element from the array
818+
if List.contains(elems, thresh, intEq) then Expression.makeRange(Expression.INTEGER(thresh), NONE(), Expression.INTEGER(thresh))
819+
// this element is not in the range >>> no valid element
820+
else Expression.makeRange(Expression.INTEGER(0), SOME(Expression.INTEGER(0)), Expression.INTEGER(0));
821+
822+
// i <>, <, <=, >, >= VAL as a condition
823+
case NFOperator.Op.NEQUAL then Expression.makeExpArray(listArray(list(Expression.INTEGER(i) for i guard(i <> thresh) in elems)), Type.INTEGER(), true);
824+
case NFOperator.Op.LESS then Expression.makeExpArray(listArray(list(Expression.INTEGER(i) for i guard(i < thresh) in elems)), Type.INTEGER(), true);
825+
case NFOperator.Op.LESSEQ then Expression.makeExpArray(listArray(list(Expression.INTEGER(i) for i guard(i <= thresh) in elems)), Type.INTEGER(), true);
826+
case NFOperator.Op.GREATER then Expression.makeExpArray(listArray(list(Expression.INTEGER(i) for i guard(i > thresh) in elems)), Type.INTEGER(), true);
827+
case NFOperator.Op.GREATEREQ then Expression.makeExpArray(listArray(list(Expression.INTEGER(i) for i guard(i >= thresh) in elems)), Type.INTEGER(), true);
828+
829+
else algorithm
830+
Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed for operator: " + Operator.toDebugString(operator)});
831+
then fail();
832+
end match;
833+
end adaptArray;
834+
794835
function toString
795836
input Iterator iter;
796837
output String str = "";

testsuite/simulation/modelica/NBackend/array_handling/for_if.mos

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,52 @@ equation
4343
end if;
4444
end for;
4545
end for_if;
46+
47+
model for_if_lists
48+
Real[5] x, y, z;
49+
equation
50+
// <> and ==
51+
for i in {1,2,3,4,5} loop
52+
if i <> 3 then
53+
x[i] = time + i;
54+
end if;
55+
end for;
56+
for i in {5,4,3,2,1} loop
57+
if i == 3 then
58+
x[i] = time;
59+
end if;
60+
end for;
61+
62+
// < and >=
63+
for i in {1,2,3,4,5} loop
64+
if i < 3 then
65+
y[i] = time + i;
66+
end if;
67+
end for;
68+
for i in {5,4,3,2,1} loop
69+
if i >= 3 then
70+
y[i] = time - 1;
71+
end if;
72+
end for;
73+
74+
// > and <=
75+
for i in {1,2,3,4,5} loop
76+
if i > 3 then
77+
z[i] = time + i;
78+
end if;
79+
end for;
80+
for i in {5,4,3,2,1} loop
81+
if i <= 3 then
82+
z[i] = time - 1;
83+
end if;
84+
end for;
85+
end for_if_lists;
4686
"); getErrorString();
4787

4888
setCommandLineOptions("--newBackend, -d=dumpSimplify");
4989

5090
simulate(for_if); getErrorString();
91+
simulate(for_if_lists); getErrorString();
5192

5293
// Result:
5394
// true
@@ -156,4 +197,107 @@ simulate(for_if); getErrorString();
156197
// "
157198
// end SimulationResult;
158199
// ""
200+
// ### dumpSimplify | ###
201+
// [BEFORE]
202+
// [FOR-] (5)
203+
// [----] for i in {1, 2, 3, 4, 5} loop
204+
// [----] [-IF-] (1)
205+
// [----] [----] if i <> 3 then
206+
// [----] [----] [SCAL] (1) x[i] = time + i
207+
// [----] [----] end if;
208+
// [----] end for;
209+
// [AFTER ]
210+
// [FOR-] (4)
211+
// [----] for i in {1, 2, 4, 5} loop
212+
// [----] [SCAL] (1) x[i] = time + i
213+
// [----] end for;
214+
//
215+
// ### dumpSimplify | ###
216+
// [BEFORE]
217+
// [FOR-] (5)
218+
// [----] for i in {5, 4, 3, 2, 1} loop
219+
// [----] [-IF-] (1)
220+
// [----] [----] if i == 3 then
221+
// [----] [----] [SCAL] (1) x[i] = time
222+
// [----] [----] end if;
223+
// [----] end for;
224+
// [AFTER ]
225+
// [SCAL] (1) x[3] = time
226+
//
227+
// ### dumpSimplify | ###
228+
// [BEFORE]
229+
// [FOR-] (5)
230+
// [----] for i in {1, 2, 3, 4, 5} loop
231+
// [----] [-IF-] (1)
232+
// [----] [----] if i < 3 then
233+
// [----] [----] [SCAL] (1) y[i] = time + i
234+
// [----] [----] end if;
235+
// [----] end for;
236+
// [AFTER ]
237+
// [FOR-] (2)
238+
// [----] for i in {1, 2} loop
239+
// [----] [SCAL] (1) y[i] = time + i
240+
// [----] end for;
241+
//
242+
// ### dumpSimplify | ###
243+
// [BEFORE] time - 1.0
244+
// [AFTER ] (-1.0) + time
245+
//
246+
// ### dumpSimplify | ###
247+
// [BEFORE]
248+
// [FOR-] (5)
249+
// [----] for i in {5, 4, 3, 2, 1} loop
250+
// [----] [-IF-] (1)
251+
// [----] [----] if i >= 3 then
252+
// [----] [----] [SCAL] (1) y[i] = (-1.0) + time
253+
// [----] [----] end if;
254+
// [----] end for;
255+
// [AFTER ]
256+
// [FOR-] (3)
257+
// [----] for i in {5, 4, 3} loop
258+
// [----] [SCAL] (1) y[i] = (-1.0) + time
259+
// [----] end for;
260+
//
261+
// ### dumpSimplify | ###
262+
// [BEFORE]
263+
// [FOR-] (5)
264+
// [----] for i in {1, 2, 3, 4, 5} loop
265+
// [----] [-IF-] (1)
266+
// [----] [----] if i > 3 then
267+
// [----] [----] [SCAL] (1) z[i] = time + i
268+
// [----] [----] end if;
269+
// [----] end for;
270+
// [AFTER ]
271+
// [FOR-] (2)
272+
// [----] for i in {4, 5} loop
273+
// [----] [SCAL] (1) z[i] = time + i
274+
// [----] end for;
275+
//
276+
// ### dumpSimplify | ###
277+
// [BEFORE] time - 1.0
278+
// [AFTER ] (-1.0) + time
279+
//
280+
// ### dumpSimplify | ###
281+
// [BEFORE]
282+
// [FOR-] (5)
283+
// [----] for i in {5, 4, 3, 2, 1} loop
284+
// [----] [-IF-] (1)
285+
// [----] [----] if i <= 3 then
286+
// [----] [----] [SCAL] (1) z[i] = (-1.0) + time
287+
// [----] [----] end if;
288+
// [----] end for;
289+
// [AFTER ]
290+
// [FOR-] (3)
291+
// [----] for i in {3, 2, 1} loop
292+
// [----] [SCAL] (1) z[i] = (-1.0) + time
293+
// [----] end for;
294+
//
295+
// record SimulationResult
296+
// resultFile = "for_if_lists_res.mat",
297+
// simulationOptions = "startTime = 0.0, stopTime = 1.0, numberOfIntervals = 500, tolerance = 1e-6, method = 'dassl', fileNamePrefix = 'for_if_lists', options = '', outputFormat = 'mat', variableFilter = '.*', cflags = '', simflags = ''",
298+
// messages = "LOG_SUCCESS | info | The initialization finished successfully without homotopy method.
299+
// LOG_SUCCESS | info | The simulation finished successfully.
300+
// "
301+
// end SimulationResult;
302+
// ""
159303
// endResult

0 commit comments

Comments
 (0)