Skip to content

Commit

Permalink
[NF] Evaluate cat and promote operators
Browse files Browse the repository at this point in the history
The cat operator calls ExpressionSimplify.evalCat which has been
rewritten to be a higher-order function that allows you to perform
cat on any tree structure (to reduce maintenance effort).

Belonging to [master]:
  - OpenModelica/OMCompiler#2243
  • Loading branch information
sjoelund authored and OpenModelica-Hudson committed Feb 28, 2018
1 parent 85e6162 commit 46705f7
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 25 deletions.
10 changes: 10 additions & 0 deletions Compiler/FrontEnd/Expression.mo
Expand Up @@ -3564,6 +3564,16 @@ algorithm
outArray := DAE.ARRAY(inType, inScalar, inElements);
end makeArray;

public function makeArrayFromList
input list<DAE.Exp> inElements;
output DAE.Exp outArray;
protected
DAE.Type ty;
algorithm
ty := typeof(listHead(inElements));
outArray := DAE.ARRAY(ty, not Types.isArray(ty), inElements);
end makeArrayFromList;

public function makeScalarArray
"Constructs an array of the given scalar type."
input list<DAE.Exp> inExpLst;
Expand Down
72 changes: 48 additions & 24 deletions Compiler/FrontEnd/ExpressionSimplify.mo
Expand Up @@ -1097,6 +1097,7 @@ algorithm
Real r1;
array<array<DAE.Exp>> marr;
String name, s1, s2;
list<Integer> dims;

// If the argument to min/max is an array, try to flatten it.
case (DAE.CALL(path=Absyn.IDENT(name),expLst={e as DAE.ARRAY()},
Expand Down Expand Up @@ -1378,7 +1379,8 @@ algorithm

case DAE.CALL(path=Absyn.IDENT("cat"),expLst=DAE.ICONST(i)::es,attr=DAE.CALL_ATTR(ty=tp))
algorithm
e := evalCat(i,es);
(es,dims) := evalCat(i, es, getArrayContents=Expression.getArrayOrMatrixContents, toString=ExpressionDump.printExpStr);
e := Expression.listToArray(es, list(DAE.DIM_INTEGER(d) for d in dims));
then e;

// promote n-dim to n-dim
Expand Down Expand Up @@ -1604,37 +1606,50 @@ algorithm
end matchcontinue;
end simplifyCat2;

protected function evalCat
public function evalCat<Exp>
input Integer dim;
input list<DAE.Exp> exps;
output DAE.Exp outExp;
input list<Exp> exps;
input GetArrayContents getArrayContents;
input ToString toString;
output list<Exp> outExps;
output list<Integer> outDims;
partial function GetArrayContents
input Exp e;
output list<Exp> es;
end GetArrayContents;
partial function MakeArrayFromList
input list<Exp> es;
output Exp e;
end MakeArrayFromList;
partial function ToString
input Exp e;
output String s;
end ToString;
protected
list<DAE.Exp> arr;
list<list<DAE.Exp>> arrs={};
list<Exp> arr;
list<list<Exp>> arrs={};
list<Integer> dims, firstDims={}, lastDims, reverseDims;
list<list<Integer>> dimsLst={};
DAE.Type tp;
Integer j, k, l, thisDim, lastDim;
array<DAE.Exp> expArr;
array<Exp> expArr;
algorithm
true := dim >= 1;
false := listEmpty(exps);
if 1 == dim then
arr := listAppend(Expression.getArrayOrMatrixContents(e) for e in exps);
tp := Expression.typeof(listHead(arr));
outExp := Expression.makeArray(arr, tp, not Types.isArray(tp));
outExps := listAppend(getArrayContents(e) for e in exps);
outDims := {listLength(outExps)};
return;
end if;
for e in listReverse(exps) loop
// Here we get a linear representation of all expressions in the array
// and the dimensions necessary to build up the array again
(arr,dims) := evalCatGetFlatArray(e,dim);
(arr,dims) := evalCatGetFlatArray(e, dim, getArrayContents=getArrayContents, toString=toString);
arrs := arr::arrs;
dimsLst := dims::dimsLst;
end for;
for i in 1:(dim-1) loop
j := min(listHead(d) for d in dimsLst);
Error.assertion(j == max(listHead(d) for d in dimsLst), getInstanceName() + ": cat got uneven dimensions for dim=" + String(i) + " " + ExpressionDump.printExpStr(DAE.LIST(exps)), sourceInfo());
Error.assertion(j == max(listHead(d) for d in dimsLst), getInstanceName() + ": cat got uneven dimensions for dim=" + String(i) + " " + stringDelimitList(list(toString(e) for e in exps), ", "), sourceInfo());
firstDims := j :: firstDims;
dimsLst := list(listRest(d) for d in dimsLst);
end for;
Expand All @@ -1658,33 +1673,42 @@ algorithm
end for;
// Convert the flat array structure to a tree array structure with the
// correct dimensions
arr := arrayList(expArr);
outExp := Expression.listToArray(arr, listReverse(DAE.DIM_INTEGER(d) for d in reverseDims));
outExps := arrayList(expArr);
outDims := listReverse(reverseDims);
end evalCat;

protected function evalCatGetFlatArray
input DAE.Exp e;
protected function evalCatGetFlatArray<Exp>
input Exp e;
input Integer dim;
output list<DAE.Exp> outExps={};
input GetArrayContents getArrayContents;
input ToString toString;
output list<Exp> outExps={};
output list<Integer> outDims={};
partial function GetArrayContents
input Exp e;
output list<Exp> es;
end GetArrayContents;
partial function ToString
input Exp e;
output String s;
end ToString;
protected
list<DAE.Exp> arr;
list<Exp> arr;
list<Integer> dims;
DAE.Type tp;
Integer i;
algorithm
if dim == 1 then
outExps := Expression.getArrayOrMatrixContents(e);
outExps := getArrayContents(e);
outDims := {listLength(outExps)};
return;
end if;
i := 0;
for exp in listReverse(Expression.getArrayOrMatrixContents(e)) loop
(arr, dims) := evalCatGetFlatArray(exp, dim-1);
for exp in listReverse(getArrayContents(e)) loop
(arr, dims) := evalCatGetFlatArray(exp, dim-1, getArrayContents=getArrayContents, toString=toString);
if listEmpty(outDims) then
outDims := dims;
else
Error.assertion(valueEq(dims, outDims), getInstanceName() + ": Got unbalanced array from " + ExpressionDump.printExpStr(e), sourceInfo());
Error.assertion(valueEq(dims, outDims), getInstanceName() + ": Got unbalanced array from " + toString(e), sourceInfo());
end if;
outExps := listAppend(arr, outExps);
i := i+1;
Expand Down
67 changes: 66 additions & 1 deletion Compiler/NFFrontEnd/NFCeval.mo
Expand Up @@ -43,6 +43,7 @@ import NFCall.Call;
import Dimension = NFDimension;
import Type = NFType;
import NFTyping.ExpOrigin;
import ExpressionSimplify;

protected
import NFFunction.Function;
Expand Down Expand Up @@ -344,7 +345,7 @@ algorithm
case "asin" then evalBuiltinAsin(listHead(args), target);
case "atan2" then evalBuiltinAtan2(args);
case "atan" then evalBuiltinAtan(listHead(args));
//case "cat" then evalBuiltinCat(args, target);
case "cat" then evalBuiltinCat(listHead(args), listRest(args), target);
case "ceil" then evalBuiltinCeil(listHead(args));
case "cosh" then evalBuiltinCosh(listHead(args));
case "cos" then evalBuiltinCos(listHead(args));
Expand All @@ -367,6 +368,7 @@ algorithm
case "noEvent" then listHead(args); // No events during ceval, just return the argument.
case "ones" then evalBuiltinOnes(args);
case "product" then evalBuiltinProduct(listHead(args));
case "promote" then evalBuiltinPromote(listGet(args,1),listGet(args,2));
case "rem" then evalBuiltinRem(args, target);
case "scalar" then evalBuiltinScalar(args);
case "sign" then evalBuiltinSign(listHead(args));
Expand Down Expand Up @@ -537,6 +539,30 @@ algorithm
end match;
end evalBuiltinAtan;

function evalBuiltinCat
input Expression argN;
input list<Expression> args;
input EvalTarget target;
output Expression result;
protected
Integer n, nd;
Type ty;
list<Expression> es;
list<Integer> dims;
algorithm
Expression.INTEGER(n) := argN;
ty := Expression.typeOf(listHead(args));
nd := Type.dimensionCount(ty);
if n > nd or n < 1 then
if EvalTarget.hasInfo(target) then
Error.addSourceMessage(Error.ARGUMENT_OUT_OF_RANGE, {String(n), "cat", "1 <= x <= " + String(nd)}, EvalTarget.getInfo(target));
end if;
fail();
end if;
(es,dims) := ExpressionSimplify.evalCat(n, args, getArrayContents=Expression.arrayElements, toString=Expression.toString);
result := Expression.arrayFromList(es, Type.arrayElementType(ty), list(Dimension.INTEGER(d) for d in dims));
end evalBuiltinCat;

function evalBuiltinCeil
input Expression arg;
output Expression result;
Expand Down Expand Up @@ -931,6 +957,45 @@ algorithm
end match;
end evalBuiltinProductReal;

function evalBuiltinPromote
input Expression arg, argN;
output Expression result;
protected
Integer n, numToPromote;
Type ty;
algorithm
Expression.INTEGER(n) := argN;
ty := Expression.typeOf(arg);
numToPromote := n - Type.dimensionCount(ty);
result := evalBuiltinPromoteWork(arg, n);
end evalBuiltinPromote;

function evalBuiltinPromoteWork
input Expression arg;
input Integer n;
output Expression result;
protected
Expression exp;
list<Expression> exps;
Type ty;
algorithm
Error.assertion(n >= 1, "Promote called with n<1", sourceInfo());
if n == 1 then
result := Expression.ARRAY(Type.liftArrayLeft(Expression.typeOf(arg),Dimension.INTEGER(1)), {arg});
return;
end if;
result := match arg
case Expression.ARRAY()
algorithm
(exps as (Expression.ARRAY(ty=ty)::_)) := list(evalBuiltinPromoteWork(e, n-1) for e in arg.elements);
then Expression.ARRAY(Type.liftArrayLeft(ty,Dimension.INTEGER(listLength(arg.elements))), exps);
else
algorithm
(exp as Expression.ARRAY(ty=ty)) := evalBuiltinPromoteWork(arg, n-1);
then Expression.ARRAY(Type.liftArrayLeft(ty,Dimension.INTEGER(1)), {exp});
end match;
end evalBuiltinPromoteWork;

function evalBuiltinRem
input list<Expression> args;
input EvalTarget target;
Expand Down

0 comments on commit 46705f7

Please sign in to comment.