Skip to content

Commit

Permalink
[NF] Improve evaluation of reductions.
Browse files Browse the repository at this point in the history
- Evaluate iterations ranges before doing bindingExpMap since the
  evaluated ranges might contain binding expressions.
  • Loading branch information
perost committed Jun 12, 2020
1 parent 737ef3b commit 2134f9f
Showing 1 changed file with 35 additions and 46 deletions.
81 changes: 35 additions & 46 deletions OMCompiler/Compiler/NFFrontEnd/NFCeval.mo
Expand Up @@ -1806,10 +1806,18 @@ algorithm
Expression.bindingExpMap(Expression.CALL(c), evalNormalCallExp);

case Call.TYPED_ARRAY_CONSTRUCTOR()
then evalArrayConstructor(c.exp, c.iters);
algorithm
c.exp := evalExpPartial(c.exp);
c.iters := list((Util.tuple21(i), evalExp_impl(Util.tuple22(i), target)) for i in c.iters);
then
Expression.bindingExpMap(Expression.CALL(c), evalArrayConstructor);

case Call.TYPED_REDUCTION()
then evalReduction(c.fn, c.exp, c.iters);
algorithm
c.exp := evalExpPartial(c.exp);
c.iters := list((Util.tuple21(i), evalExp_impl(Util.tuple22(i), target)) for i in c.iters);
then
Expression.bindingExpMap(Expression.CALL(c), evalReduction);

else
algorithm
Expand Down Expand Up @@ -3069,37 +3077,28 @@ algorithm
end evalBuiltinDynamicSelect;

function evalArrayConstructor
input Expression exp;
input list<tuple<InstNode, Expression>> iterators;
output Expression result;
algorithm
result := evalExpPartial(exp);
result := Expression.bindingExpMap(result,
function evalArrayConstructor2(iterators = iterators));
end evalArrayConstructor;

function evalArrayConstructor2
input Expression exp;
input list<tuple<InstNode, Expression>> iterators;
input Expression callExp;
output Expression result;
protected
Expression e;
Expression exp;
list<tuple<InstNode, Expression>> iters;
list<Mutable<Expression>> iter_exps;
list<Expression> ranges;
list<Mutable<Expression>> iters;
list<Type> types = {};
Type ty;
list<Type> types = {};
algorithm
(e, ranges, iters) := createIterationRanges(exp, iterators);
Expression.CALL(call = Call.TYPED_ARRAY_CONSTRUCTOR(exp = exp, iters = iters)) := callExp;
(exp, ranges, iter_exps) := createIterationRanges(exp, iters);

// Precompute all the types we're going to need for the arrays created.
ty := Expression.typeOf(e);
ty := Expression.typeOf(exp);
for r in ranges loop
ty := Type.liftArrayLeftList(ty, Type.arrayDims(Expression.typeOf(r)));
types := ty :: types;
end for;

result := evalArrayConstructor3(e, ranges, iters, types);
end evalArrayConstructor2;
result := evalArrayConstructor3(exp, ranges, iter_exps, types);
end evalArrayConstructor;

function createIterationRanges
input output Expression exp;
Expand All @@ -3116,7 +3115,7 @@ algorithm
iter := Mutable.create(Expression.INTEGER(0));
exp := Expression.replaceIterator(exp, node, Expression.MUTABLE(iter));
iters := iter :: iters;
ranges := evalExp_impl(range, EvalTarget.IGNORE_ERRORS()) :: ranges;
ranges := range :: ranges;
end for;
end createIterationRanges;

Expand Down Expand Up @@ -3161,30 +3160,20 @@ partial function ReductionFn
end ReductionFn;

function evalReduction
input Function fn;
input Expression exp;
input list<tuple<InstNode, Expression>> iterators;
output Expression result;
algorithm
result := evalExpPartial(exp);
result := Expression.bindingExpMap(result,
function evalReduction2(fn = fn, iterators = iterators));
end evalReduction;

function evalReduction2
input Function fn;
input Expression exp;
input list<tuple<InstNode, Expression>> iterators;
input Expression callExp;
output Expression result;
protected
Expression e, default_exp;
Function fn;
Expression exp, default_exp;
list<tuple<InstNode, Expression>> iters;
list<Mutable<Expression>> iter_exps;
list<Expression> ranges;
list<Mutable<Expression>> iters;
ReductionFn red_fn;
Type ty;
ReductionFn red_fn;
algorithm
(e, ranges, iters) := createIterationRanges(exp, iterators);
ty := Expression.typeOf(e);
Expression.CALL(call = Call.TYPED_REDUCTION(fn = fn, exp = exp, iters = iters)) := callExp;
(exp, ranges, iter_exps) := createIterationRanges(exp, iters);
ty := Expression.typeOf(exp);

(red_fn, default_exp) := match AbsynUtil.pathString(Function.name(fn))
case "sum" then (evalBinaryAdd, Expression.makeZero(ty));
Expand All @@ -3199,10 +3188,10 @@ algorithm
fail();
end match;

result := evalReduction3(e, ranges, iters, default_exp, red_fn);
end evalReduction2;
result := evalReduction2(exp, ranges, iter_exps, default_exp, red_fn);
end evalReduction;

function evalReduction3
function evalReduction2
input Expression exp;
input list<Expression> ranges;
input list<Mutable<Expression>> iterators;
Expand All @@ -3229,10 +3218,10 @@ algorithm
while ExpressionIterator.hasNext(range_iter) loop
(range_iter, value) := ExpressionIterator.next(range_iter);
Mutable.update(iter, value);
result := evalReduction3(exp, ranges_rest, iters_rest, result, fn);
result := evalReduction2(exp, ranges_rest, iters_rest, result, fn);
end while;
end if;
end evalReduction3;
end evalReduction2;

function evalSize
input Expression exp;
Expand Down

0 comments on commit 2134f9f

Please sign in to comment.