Skip to content

Commit

Permalink
Improve iterator handling (#9022)
Browse files Browse the repository at this point in the history
- Instantiate/type iterators in the correct order when there are
  multiple.
- Improve evaluation of iterators with ranges that refer to other
  iterators.

Fixes #9019
  • Loading branch information
perost committed May 25, 2022
1 parent ff8602f commit 94a591f
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 40 deletions.
15 changes: 6 additions & 9 deletions OMCompiler/Compiler/NFFrontEnd/NFCall.mo
Expand Up @@ -2103,7 +2103,7 @@ protected
InstNode iter, range_node;
Type ty;
algorithm
for i in inIters loop
for i in listReverse(inIters) loop
if isSome(i.range) then
range := Inst.instExp(Util.getOption(i.range), outScope, context, info);
else
Expand All @@ -2124,8 +2124,6 @@ protected
(outScope, iter) := Inst.addIteratorToScope(i.name, outScope, info, ty);
outIters := (iter, range) :: outIters;
end for;

outIters := listReverse(outIters);
end instIterators;

function typeArrayConstructor
Expand Down Expand Up @@ -2155,7 +2153,7 @@ protected
is_structural := not InstContext.inFunction(context);
next_context := InstContext.set(context, NFInstContext.SUBEXPRESSION);

for i in call.iters loop
for i in listReverse(call.iters) loop
(iter, range) := i;

if Expression.isEmpty(range) then
Expand All @@ -2169,12 +2167,13 @@ protected
iter_ty := Expression.typeOf(range);
end if;

dims := listAppend(Type.arrayDims(iter_ty), dims);
dims := List.append_reverse(Type.arrayDims(iter_ty), dims);
variability := Prefixes.variabilityMax(variability, iter_var);
purity := Prefixes.purityMin(purity, iter_pur);
iters := (iter, range) :: iters;
end for;
iters := listReverseInPlace(iters);

dims := listReverseInPlace(dims);

// InstContext.FOR is used here as a marker that this expression may contain iterators.
next_context := InstContext.set(next_context, NFInstContext.FOR);
Expand Down Expand Up @@ -2219,7 +2218,7 @@ protected
purity := Purity.PURE;
next_context := InstContext.set(context, NFInstContext.SUBEXPRESSION);

for i in call.iters loop
for i in listReverse(call.iters) loop
(iter, range) := i;

if Expression.isEmpty(range) then
Expand All @@ -2232,8 +2231,6 @@ protected
iters := (iter, range) :: iters;
end for;

iters := listReverseInPlace(iters);

// InstContext.FOR is used here as a marker that this expression may contain iterators.
next_context := InstContext.set(next_context, NFInstContext.FOR);
(arg, ty, exp_var, exp_pur) := Typing.typeExp(call.exp, next_context, info);
Expand Down
34 changes: 12 additions & 22 deletions OMCompiler/Compiler/NFFrontEnd/NFCeval.mo
Expand Up @@ -299,6 +299,14 @@ algorithm
end match;
end evalExpOpt;

function evalExpPartialDefault
"Simplied version of evalExpPartial to work around MetaModelica issues with
default arguments and multiple return values when used as a function pointer."
input output Expression exp;
algorithm
exp := evalExpPartial(exp);
end evalExpPartialDefault;

function evalExpPartial
"Evaluates the parts of an expression that are possible to evaluate. This
means leaving parts of the expression that contains e.g. iterators or mutable
Expand Down Expand Up @@ -1841,14 +1849,14 @@ algorithm
case Call.TYPED_ARRAY_CONSTRUCTOR()
algorithm
c.exp := evalExpPartial(c.exp);
c.iters := list((Util.tuple21(i), evalExp(Util.tuple22(i), target)) for i in c.iters);
c.iters := Call.mapIteratorsExpShallow(c.iters, evalExpPartialDefault);
then
Expression.mapSplitExpressions(Expression.CALL(c), evalArrayConstructor);

case Call.TYPED_REDUCTION()
algorithm
c.exp := evalExpPartial(c.exp);
c.iters := list((Util.tuple21(i), evalExp(Util.tuple22(i), target)) for i in c.iters);
c.iters := Call.mapIteratorsExpShallow(c.iters, evalExpPartialDefault);
then
Expression.mapSplitExpressions(Expression.CALL(c), evalReduction);

Expand Down Expand Up @@ -3075,7 +3083,7 @@ protected
list<Type> types = {};
algorithm
Expression.CALL(call = Call.TYPED_ARRAY_CONSTRUCTOR(exp = exp, iters = iters)) := callExp;
(exp, ranges, iter_exps) := createIterationRanges(exp, iters);
(exp, ranges, iter_exps) := Expression.createIterationRanges(exp, iters);

// Precompute all the types we're going to need for the arrays created.
ty := Expression.typeOf(exp);
Expand All @@ -3087,25 +3095,6 @@ algorithm
result := evalArrayConstructor3(exp, ranges, iter_exps, types);
end evalArrayConstructor;

function createIterationRanges
input output Expression exp;
input list<tuple<InstNode, Expression>> iterators;
output list<Expression> ranges = {};
output list<Mutable<Expression>> iters = {};
protected
InstNode node;
Expression range;
Mutable<Expression> iter;
algorithm
for i in iterators loop
(node, range) := i;
iter := Mutable.create(Expression.INTEGER(0));
exp := Expression.replaceIterator(exp, node, Expression.MUTABLE(iter));
iters := iter :: iters;
ranges := range :: ranges;
end for;
end createIterationRanges;

function evalArrayConstructor3
input Expression exp;
input list<Expression> ranges;
Expand All @@ -3126,6 +3115,7 @@ algorithm
result := evalExp(exp, EvalTarget.IGNORE_ERRORS());
else
range :: ranges_rest := ranges;
range := evalExp(range);
iter :: iters_rest := iterators;
ty :: rest_ty := types;
range_iter := ExpressionIterator.fromExp(range);
Expand Down
31 changes: 22 additions & 9 deletions OMCompiler/Compiler/NFFrontEnd/NFExpression.mo
Expand Up @@ -5366,6 +5366,26 @@ public
exp := ENUM_LITERAL(ty, Type.nthEnumLiteral(ty, n), n);
end nthEnumLiteral;

function createIterationRanges
input output Expression exp;
input list<tuple<InstNode, Expression>> iterators;
output list<Expression> ranges = {};
output list<Mutable<Expression>> iters = {};
protected
InstNode node;
Expression range;
Mutable<Expression> iter;
algorithm
for i in iterators loop
(node, range) := i;
iter := Mutable.create(INTEGER(0));
ranges := list(replaceIterator(r, node, MUTABLE(iter)) for r in ranges);
exp := replaceIterator(exp, node, MUTABLE(iter));
iters := iter :: iters;
ranges := range :: ranges;
end for;
end createIterationRanges;

function foldReduction
input Expression exp;
input list<tuple<InstNode, Expression>> iterators;
Expand All @@ -5390,15 +5410,7 @@ public
list<Expression> ranges = {};
list<Mutable<Expression>> iters = {};
algorithm
e := exp;
for i in iterators loop
(node, range) := i;
iter := Mutable.create(INTEGER(0));
e := replaceIterator(e, node, MUTABLE(iter));
iters := iter :: iters;
ranges := range :: ranges;
end for;

(e, ranges, iters) := createIterationRanges(exp, iterators);
result := foldReduction2(e, ranges, iters, foldExp, mapFn, foldFn);
end foldReduction;

Expand Down Expand Up @@ -5431,6 +5443,7 @@ public
result := foldFn(foldExp, mapFn(exp));
else
range :: ranges_rest := ranges;
range := Ceval.evalExp(range);
iter :: iters_rest := iterators;
range_iter := ExpressionIterator.fromExp(range);
result := foldExp;
Expand Down
16 changes: 16 additions & 0 deletions testsuite/flattening/modelica/scodeinst/CevalReduction2.mo
@@ -0,0 +1,16 @@
// name: CevalReduction2
// keywords:
// status: correct
// cflags: -d=newInst
//
//

model CevalReduction2
constant Real x = sum(j for j in 1:i, i in 1:4);
end CevalReduction2;

// Result:
// class CevalReduction2
// constant Real x = 20.0;
// end CevalReduction2;
// endResult
2 changes: 2 additions & 0 deletions testsuite/flattening/modelica/scodeinst/Makefile
Expand Up @@ -171,6 +171,8 @@ CevalMul1.mo \
CevalNoEvent1.mo \
CevalOnes1.mo \
CevalProduct1.mo \
CevalReduction1.mo \
CevalReduction2.mo \
CevalRecord1.mo \
CevalRecord2.mo \
CevalRecord3.mo \
Expand Down

0 comments on commit 94a591f

Please sign in to comment.