Skip to content
This repository has been archived by the owner on May 18, 2019. It is now read-only.

Commit

Permalink
[NF] Improve handling of reductions.
Browse files Browse the repository at this point in the history
- Separate the handling of array constructors and reductions so
  reductions can be handled properly.
- Implement expansion of type names, to better handle enumeration names
  as iteration ranges.
- Expand enumeration type names in Expression.toDAE, so that they can be
  converted to DAE-form.
- Add missing case for enumeration literals in Expression.compare.
- Don't fill in "missing" subscripts in crefs with :, it interfers with
  reductions in some cases and doesn't seem to have any benefits.

Belonging to [master]:
  - #2703
  - OpenModelica/OpenModelica-testsuite#1047
  • Loading branch information
perost authored and OpenModelica-Hudson committed Oct 8, 2018
1 parent 63c9d14 commit f34bf81
Show file tree
Hide file tree
Showing 9 changed files with 462 additions and 177 deletions.
239 changes: 150 additions & 89 deletions Compiler/NFFrontEnd/NFCall.mo

Large diffs are not rendered by default.

130 changes: 95 additions & 35 deletions Compiler/NFFrontEnd/NFCeval.mo
Expand Up @@ -60,6 +60,7 @@ import ExpressionIterator = NFExpressionIterator;
import MetaModelica.Dangerous.*;
import NFClass.Class;
import TypeCheck = NFTypeCheck;
import ExpandExp = NFExpandExp;

public
uniontype EvalTarget
Expand Down Expand Up @@ -577,32 +578,10 @@ function evalTypename
input Expression originExp;
input EvalTarget target;
output Expression exp;
protected
list<Expression> lits;
algorithm
// Only expand the typename into an array if it's used as a range, and keep
// them as typenames when used as e.g. dimensions.
if not EvalTarget.isRange(target) then
exp := originExp;
else
exp := match ty
case Type.ARRAY(elementType = Type.BOOLEAN())
then Expression.ARRAY(ty, {Expression.BOOLEAN(false), Expression.BOOLEAN(true)});

case Type.ARRAY(elementType = Type.ENUMERATION())
algorithm
lits := Expression.makeEnumLiterals(ty.elementType);
then
Expression.ARRAY(ty, lits);

else
algorithm
Error.addInternalError(getInstanceName() + " got invalid typename", sourceInfo());
then
fail();

end match;
end if;
exp := if EvalTarget.isRange(target) then ExpandExp.expandTypename(ty) else originExp;
end evalTypename;

function evalRange
Expand Down Expand Up @@ -1464,8 +1443,11 @@ algorithm
else
evalNormalCall(call.fn, args);

case Call.TYPED_MAP_CALL()
then evalReduction(call.exp, call.ty, call.iters);
case Call.TYPED_ARRAY_CONSTRUCTOR()
then evalArrayConstructor(call.exp, call.ty, call.iters);

case Call.TYPED_REDUCTION()
then evalReduction(call.fn, call.exp, call.ty, call.iters);

else
algorithm
Expand Down Expand Up @@ -2627,30 +2609,40 @@ algorithm
end match;
end evalSolverClock;

function evalReduction
function evalArrayConstructor
input Expression exp;
input Type ty;
input list<tuple<InstNode, Expression>> iterators;
output Expression result;
protected
Expression e = exp, range;
Expression e;
list<Expression> ranges;
list<Mutable<Expression>> iters;
algorithm
(e, ranges, iters) := createIterationRanges(exp, iterators);
result := evalArrayConstructor2(e, ty, ranges, iters);
end evalArrayConstructor;

function createIterationRanges
input output Expression exp;
input list<tuple<InstNode, Expression>> iterators;
output list<Expression> ranges = {};
output list<Mutable<Expression>> iters = {};
protected
InstNode node;
list<Expression> ranges = {}, expl;
Expression range;
Mutable<Expression> iter;
list<Mutable<Expression>> iters = {};
algorithm
for i in iterators loop
(node, range) := i;
iter := Mutable.create(Expression.INTEGER(0));
e := Expression.replaceIterator(e, node, Expression.MUTABLE(iter));
exp := Expression.replaceIterator(exp, node, Expression.MUTABLE(iter));
iters := iter :: iters;
ranges := evalExp(range) :: ranges;
end for;
end createIterationRanges;

result := evalReduction2(e, ty, ranges, iters);
end evalReduction;

function evalReduction2
function evalArrayConstructor2
input Expression exp;
input Type ty;
input list<Expression> ranges;
Expand All @@ -2676,11 +2668,79 @@ algorithm
while ExpressionIterator.hasNext(range_iter) loop
(range_iter, value) := ExpressionIterator.next(range_iter);
Mutable.update(iter, value);
expl := evalReduction2(exp, el_ty, ranges_rest, iters_rest) :: expl;
expl := evalArrayConstructor2(exp, el_ty, ranges_rest, iters_rest) :: expl;
end while;

result := Expression.ARRAY(ty, listReverseInPlace(expl));
end if;
end evalArrayConstructor2;

partial function ReductionFn
input Expression exp1;
input Expression exp2;
output Expression result;
end ReductionFn;

function evalReduction
input Function fn;
input Expression exp;
input Type ty;
input list<tuple<InstNode, Expression>> iterators;
output Expression result;
protected
Expression e, default_exp;
list<Expression> ranges;
list<Mutable<Expression>> iters;
ReductionFn red_fn;
algorithm
(e, ranges, iters) := createIterationRanges(exp, iterators);

(red_fn, default_exp) := match Absyn.pathString(Function.name(fn))
case "sum" then (evalBinaryAdd, Expression.makeZero(ty));
case "product" then (evalBinaryMul, Expression.makeOne(ty));
case "min" then (evalBuiltinMin2, Expression.makeMaxValue(ty));
case "max" then (evalBuiltinMax2, Expression.makeMinValue(ty));
else
algorithm
Error.assertion(false, getInstanceName() + " got unknown reduction function " +
Absyn.pathString(Function.name(fn)), sourceInfo());
then
fail();
end match;

result := evalReduction2(e, ranges, iters, default_exp, red_fn);
end evalReduction;

function evalReduction2
input Expression exp;
input list<Expression> ranges;
input list<Mutable<Expression>> iterators;
input Expression foldExp;
input ReductionFn fn;
output Expression result;
protected
Expression range;
list<Expression> ranges_rest, expl = {};
Mutable<Expression> iter;
list<Mutable<Expression>> iters_rest;
ExpressionIterator range_iter;
Expression value;
Type el_ty;
algorithm
if listEmpty(ranges) then
result := fn(foldExp, evalExp(exp));
else
range :: ranges_rest := ranges;
iter :: iters_rest := iterators;
range_iter := ExpressionIterator.fromExp(range);
result := foldExp;

while ExpressionIterator.hasNext(range_iter) loop
(range_iter, value) := ExpressionIterator.next(range_iter);
Mutable.update(iter, value);
result := evalReduction2(exp, ranges_rest, iters_rest, result, fn);
end while;
end if;
end evalReduction2;

function evalSize
Expand Down
46 changes: 36 additions & 10 deletions Compiler/NFFrontEnd/NFExpandExp.mo
Expand Up @@ -68,6 +68,7 @@ public
(exp, expanded);

case Expression.ARRAY() then (exp, true);
case Expression.TYPENAME() then (expandTypename(exp.ty), true);
case Expression.RANGE() then expandRange(exp);
case Expression.CALL() then expandCall(exp.call, exp);
case Expression.SIZE() then expandSize(exp);
Expand Down Expand Up @@ -198,6 +199,31 @@ public
end match;
end expandCref4;

function expandTypename
input Type ty;
output Expression outExp;
algorithm
outExp := match ty
local
list<Expression> lits;

case Type.ARRAY(elementType = Type.BOOLEAN())
then Expression.ARRAY(ty, {Expression.BOOLEAN(false), Expression.BOOLEAN(true)});

case Type.ARRAY(elementType = Type.ENUMERATION())
algorithm
lits := Expression.makeEnumLiterals(ty.elementType);
then
Expression.ARRAY(ty, lits);

else
algorithm
Error.addInternalError(getInstanceName() + " got invalid typename", sourceInfo());
then
fail();
end match;
end expandTypename;

function expandRange
input Expression exp;
output Expression outExp;
Expand Down Expand Up @@ -227,8 +253,8 @@ public
guard Function.isBuiltin(call.fn) and not Function.isImpure(call.fn)
then expandBuiltinCall(call.fn, call.arguments, call);

case Call.TYPED_MAP_CALL()
then expandReduction(call.exp, call.ty, call.iters);
case Call.TYPED_ARRAY_CONSTRUCTOR()
then expandArrayConstructor(call.exp, call.ty, call.iters);

else expandGeneric(exp);
end matchcontinue;
Expand Down Expand Up @@ -340,7 +366,7 @@ public
end match;
end expandBuiltinGeneric2;

function expandReduction
function expandArrayConstructor
input Expression exp;
input Type ty;
input list<tuple<InstNode, Expression>> iterators;
Expand All @@ -362,10 +388,10 @@ public
ranges := range :: ranges;
end for;

result := expandReduction2(e, ty, ranges, iters);
end expandReduction;
result := expandArrayConstructor2(e, ty, ranges, iters);
end expandArrayConstructor;

function expandReduction2
function expandArrayConstructor2
input Expression exp;
input Type ty;
input list<Expression> ranges;
Expand All @@ -384,8 +410,8 @@ public
// Normally it wouldn't be the expansion's task to simplify expressions,
// but we make an exception here since the generated expressions contain
// MUTABLE expressions that we need to get rid of. Also, expansion of
// reductions is often done during the scalarization phase, after the
// simplification phase, so they wouldn't otherwise be simplified.
// array constructors is often done during the scalarization phase, after
// the simplification phase, so they wouldn't otherwise be simplified.
result := expand(SimplifyExp.simplify(exp));
else
range :: ranges_rest := ranges;
Expand All @@ -396,12 +422,12 @@ public
while ExpressionIterator.hasNext(range_iter) loop
(range_iter, value) := ExpressionIterator.next(range_iter);
Mutable.update(iter, value);
expl := expandReduction2(exp, el_ty, ranges_rest, iters_rest) :: expl;
expl := expandArrayConstructor2(exp, el_ty, ranges_rest, iters_rest) :: expl;
end while;

result := Expression.ARRAY(ty, listReverseInPlace(expl));
end if;
end expandReduction2;
end expandArrayConstructor2;

function expandSize
input Expression exp;
Expand Down

0 comments on commit f34bf81

Please sign in to comment.