Skip to content
This repository has been archived by the owner on May 18, 2019. It is now read-only.

[NF] Improve handling of reductions. #2703

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
239 changes: 150 additions & 89 deletions Compiler/NFFrontEnd/NFCall.mo

Large diffs are not rendered by default.

130 changes: 95 additions & 35 deletions Compiler/NFFrontEnd/NFCeval.mo
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ import ExpressionIterator = NFExpressionIterator;
import MetaModelica.Dangerous.*;
import NFClass.Class;
import TypeCheck = NFTypeCheck;
import ExpandExp = NFExpandExp;

public
uniontype EvalTarget
Expand Down Expand Up @@ -577,32 +578,10 @@ function evalTypename
input Expression originExp;
input EvalTarget target;
output Expression exp;
protected
list<Expression> lits;
algorithm
// Only expand the typename into an array if it's used as a range, and keep
// them as typenames when used as e.g. dimensions.
if not EvalTarget.isRange(target) then
exp := originExp;
else
exp := match ty
case Type.ARRAY(elementType = Type.BOOLEAN())
then Expression.ARRAY(ty, {Expression.BOOLEAN(false), Expression.BOOLEAN(true)});

case Type.ARRAY(elementType = Type.ENUMERATION())
algorithm
lits := Expression.makeEnumLiterals(ty.elementType);
then
Expression.ARRAY(ty, lits);

else
algorithm
Error.addInternalError(getInstanceName() + " got invalid typename", sourceInfo());
then
fail();

end match;
end if;
exp := if EvalTarget.isRange(target) then ExpandExp.expandTypename(ty) else originExp;
end evalTypename;

function evalRange
Expand Down Expand Up @@ -1464,8 +1443,11 @@ algorithm
else
evalNormalCall(call.fn, args);

case Call.TYPED_MAP_CALL()
then evalReduction(call.exp, call.ty, call.iters);
case Call.TYPED_ARRAY_CONSTRUCTOR()
then evalArrayConstructor(call.exp, call.ty, call.iters);

case Call.TYPED_REDUCTION()
then evalReduction(call.fn, call.exp, call.ty, call.iters);

else
algorithm
Expand Down Expand Up @@ -2627,30 +2609,40 @@ algorithm
end match;
end evalSolverClock;

function evalReduction
function evalArrayConstructor
input Expression exp;
input Type ty;
input list<tuple<InstNode, Expression>> iterators;
output Expression result;
protected
Expression e = exp, range;
Expression e;
list<Expression> ranges;
list<Mutable<Expression>> iters;
algorithm
(e, ranges, iters) := createIterationRanges(exp, iterators);
result := evalArrayConstructor2(e, ty, ranges, iters);
end evalArrayConstructor;

function createIterationRanges
input output Expression exp;
input list<tuple<InstNode, Expression>> iterators;
output list<Expression> ranges = {};
output list<Mutable<Expression>> iters = {};
protected
InstNode node;
list<Expression> ranges = {}, expl;
Expression range;
Mutable<Expression> iter;
list<Mutable<Expression>> iters = {};
algorithm
for i in iterators loop
(node, range) := i;
iter := Mutable.create(Expression.INTEGER(0));
e := Expression.replaceIterator(e, node, Expression.MUTABLE(iter));
exp := Expression.replaceIterator(exp, node, Expression.MUTABLE(iter));
iters := iter :: iters;
ranges := evalExp(range) :: ranges;
end for;
end createIterationRanges;

result := evalReduction2(e, ty, ranges, iters);
end evalReduction;

function evalReduction2
function evalArrayConstructor2
input Expression exp;
input Type ty;
input list<Expression> ranges;
Expand All @@ -2676,11 +2668,79 @@ algorithm
while ExpressionIterator.hasNext(range_iter) loop
(range_iter, value) := ExpressionIterator.next(range_iter);
Mutable.update(iter, value);
expl := evalReduction2(exp, el_ty, ranges_rest, iters_rest) :: expl;
expl := evalArrayConstructor2(exp, el_ty, ranges_rest, iters_rest) :: expl;
end while;

result := Expression.ARRAY(ty, listReverseInPlace(expl));
end if;
end evalArrayConstructor2;

partial function ReductionFn
input Expression exp1;
input Expression exp2;
output Expression result;
end ReductionFn;

function evalReduction
input Function fn;
input Expression exp;
input Type ty;
input list<tuple<InstNode, Expression>> iterators;
output Expression result;
protected
Expression e, default_exp;
list<Expression> ranges;
list<Mutable<Expression>> iters;
ReductionFn red_fn;
algorithm
(e, ranges, iters) := createIterationRanges(exp, iterators);

(red_fn, default_exp) := match Absyn.pathString(Function.name(fn))
case "sum" then (evalBinaryAdd, Expression.makeZero(ty));
case "product" then (evalBinaryMul, Expression.makeOne(ty));
case "min" then (evalBuiltinMin2, Expression.makeMaxValue(ty));
case "max" then (evalBuiltinMax2, Expression.makeMinValue(ty));
else
algorithm
Error.assertion(false, getInstanceName() + " got unknown reduction function " +
Absyn.pathString(Function.name(fn)), sourceInfo());
then
fail();
end match;

result := evalReduction2(e, ranges, iters, default_exp, red_fn);
end evalReduction;

function evalReduction2
input Expression exp;
input list<Expression> ranges;
input list<Mutable<Expression>> iterators;
input Expression foldExp;
input ReductionFn fn;
output Expression result;
protected
Expression range;
list<Expression> ranges_rest, expl = {};
Mutable<Expression> iter;
list<Mutable<Expression>> iters_rest;
ExpressionIterator range_iter;
Expression value;
Type el_ty;
algorithm
if listEmpty(ranges) then
result := fn(foldExp, evalExp(exp));
else
range :: ranges_rest := ranges;
iter :: iters_rest := iterators;
range_iter := ExpressionIterator.fromExp(range);
result := foldExp;

while ExpressionIterator.hasNext(range_iter) loop
(range_iter, value) := ExpressionIterator.next(range_iter);
Mutable.update(iter, value);
result := evalReduction2(exp, ranges_rest, iters_rest, result, fn);
end while;
end if;
end evalReduction2;

function evalSize
Expand Down
46 changes: 36 additions & 10 deletions Compiler/NFFrontEnd/NFExpandExp.mo
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ public
(exp, expanded);

case Expression.ARRAY() then (exp, true);
case Expression.TYPENAME() then (expandTypename(exp.ty), true);
case Expression.RANGE() then expandRange(exp);
case Expression.CALL() then expandCall(exp.call, exp);
case Expression.SIZE() then expandSize(exp);
Expand Down Expand Up @@ -198,6 +199,31 @@ public
end match;
end expandCref4;

function expandTypename
input Type ty;
output Expression outExp;
algorithm
outExp := match ty
local
list<Expression> lits;

case Type.ARRAY(elementType = Type.BOOLEAN())
then Expression.ARRAY(ty, {Expression.BOOLEAN(false), Expression.BOOLEAN(true)});

case Type.ARRAY(elementType = Type.ENUMERATION())
algorithm
lits := Expression.makeEnumLiterals(ty.elementType);
then
Expression.ARRAY(ty, lits);

else
algorithm
Error.addInternalError(getInstanceName() + " got invalid typename", sourceInfo());
then
fail();
end match;
end expandTypename;

function expandRange
input Expression exp;
output Expression outExp;
Expand Down Expand Up @@ -227,8 +253,8 @@ public
guard Function.isBuiltin(call.fn) and not Function.isImpure(call.fn)
then expandBuiltinCall(call.fn, call.arguments, call);

case Call.TYPED_MAP_CALL()
then expandReduction(call.exp, call.ty, call.iters);
case Call.TYPED_ARRAY_CONSTRUCTOR()
then expandArrayConstructor(call.exp, call.ty, call.iters);

else expandGeneric(exp);
end matchcontinue;
Expand Down Expand Up @@ -340,7 +366,7 @@ public
end match;
end expandBuiltinGeneric2;

function expandReduction
function expandArrayConstructor
input Expression exp;
input Type ty;
input list<tuple<InstNode, Expression>> iterators;
Expand All @@ -362,10 +388,10 @@ public
ranges := range :: ranges;
end for;

result := expandReduction2(e, ty, ranges, iters);
end expandReduction;
result := expandArrayConstructor2(e, ty, ranges, iters);
end expandArrayConstructor;

function expandReduction2
function expandArrayConstructor2
input Expression exp;
input Type ty;
input list<Expression> ranges;
Expand All @@ -384,8 +410,8 @@ public
// Normally it wouldn't be the expansion's task to simplify expressions,
// but we make an exception here since the generated expressions contain
// MUTABLE expressions that we need to get rid of. Also, expansion of
// reductions is often done during the scalarization phase, after the
// simplification phase, so they wouldn't otherwise be simplified.
// array constructors is often done during the scalarization phase, after
// the simplification phase, so they wouldn't otherwise be simplified.
result := expand(SimplifyExp.simplify(exp));
else
range :: ranges_rest := ranges;
Expand All @@ -396,12 +422,12 @@ public
while ExpressionIterator.hasNext(range_iter) loop
(range_iter, value) := ExpressionIterator.next(range_iter);
Mutable.update(iter, value);
expl := expandReduction2(exp, el_ty, ranges_rest, iters_rest) :: expl;
expl := expandArrayConstructor2(exp, el_ty, ranges_rest, iters_rest) :: expl;
end while;

result := Expression.ARRAY(ty, listReverseInPlace(expl));
end if;
end expandReduction2;
end expandArrayConstructor2;

function expandSize
input Expression exp;
Expand Down