Skip to content

Commit

Permalink
[NF] Generate less stupid flat models.
Browse files Browse the repository at this point in the history
- Improved the expression simplification, and used it to simplify more
  things in SimplifyModel.
- Implemented constant evaluation of matrix and symmetric.
- Added squareness check to argument of symmetric.
- Improved DAE conversion of if-statements so that 'else' is created
  instead of 'elseif true then' for the last branch.

Belonging to [master]:
  - OpenModelica/OMCompiler#2434
  - OpenModelica/OpenModelica-testsuite#944
  • Loading branch information
perost authored and OpenModelica-Hudson committed May 15, 2018
1 parent fb58909 commit c097132
Show file tree
Hide file tree
Showing 11 changed files with 450 additions and 338 deletions.
6 changes: 3 additions & 3 deletions Compiler/NFFrontEnd/NFCall.mo
Expand Up @@ -2242,15 +2242,15 @@ protected

if listLength(args) <> 1 then
Error.addSourceMessageAndFail(Error.NO_MATCHING_FUNCTION_FOUND_NFINST,
{toString(call), "symmetric(Any[n, m]) => Any[n, m]"}, info);
{toString(call), "symmetric(Any[n, n]) => Any[n, n]"}, info);
end if;

(arg, ty, variability) := Typing.typeExp(listHead(args), origin, info);

if not Type.isMatrix(ty) then
if not Type.isSquareMatrix(ty) then
Error.addSourceMessageAndFail(Error.ARG_TYPE_MISMATCH,
{"1", ComponentRef.toString(fn_ref), "", Expression.toString(arg),
Type.toString(ty), "Any[:, :]"}, info);
Type.toString(ty), "Any[n, n]"}, info);
end if;

{fn} := typeCachedFunctions(fn_ref);
Expand Down
101 changes: 90 additions & 11 deletions Compiler/NFFrontEnd/NFCeval.mo
Expand Up @@ -658,7 +658,6 @@ function evalBinaryScalarProduct
input Expression exp1;
input Expression exp2;
output Expression exp;
protected
algorithm
exp := match (exp1, exp2)
local
Expand Down Expand Up @@ -1154,7 +1153,7 @@ algorithm
if Function.isBuiltin(call.fn) then
evalBuiltinCall(call.fn, args, target)
else
evalNormalCall(call.fn, args, call);
evalNormalCall(call.fn, args);

case Call.TYPED_MAP_CALL()
then evalReduction(call.exp, call.ty, call.iters);
Expand Down Expand Up @@ -1199,7 +1198,7 @@ algorithm
case "Integer" then evalBuiltinIntegerEnum(listHead(args));
case "log10" then evalBuiltinLog10(listHead(args), target);
case "log" then evalBuiltinLog(listHead(args), target);
//case "matrix" then evalBuiltinMatrix(args);
case "matrix" then evalBuiltinMatrix(listHead(args));
case "max" then evalBuiltinMax(args);
case "min" then evalBuiltinMin(args);
case "mod" then evalBuiltinMod(args);
Expand All @@ -1216,7 +1215,7 @@ algorithm
case "sqrt" then evalBuiltinSqrt(listHead(args));
case "String" then evalBuiltinString(args);
case "sum" then evalBuiltinSum(listHead(args));
//case "symmetric" then evalBuiltinSymmetric(args);
case "symmetric" then evalBuiltinSymmetric(listHead(args));
case "tanh" then evalBuiltinTanh(listHead(args));
case "tan" then evalBuiltinTan(listHead(args));
case "transpose" then evalBuiltinTranspose(listHead(args));
Expand All @@ -1232,6 +1231,12 @@ algorithm
end match;
end evalBuiltinCall;

function evalNormalCall
input Function fn;
input list<Expression> args;
output Expression result = EvalFunction.evaluate(fn, args);
end evalNormalCall;

protected

function printUnboundError
Expand Down Expand Up @@ -1265,13 +1270,6 @@ algorithm
end match;
end printUnboundError;

function evalNormalCall
input Function fn;
input list<Expression> args;
input Call call;
output Expression result = EvalFunction.evaluate(fn, args);
end evalNormalCall;

function printWrongArgsError
input String evalFunc;
input list<Expression> args;
Expand Down Expand Up @@ -1653,6 +1651,55 @@ algorithm
end match;
end evalBuiltinLog;

function evalBuiltinMatrix
input Expression arg;
output Expression result;
algorithm
result := match arg
local
Integer dim_count;
list<Expression> expl;
Dimension dim1, dim2;
Type ty;

case Expression.ARRAY(ty = ty)
algorithm
dim_count := Type.dimensionCount(ty);

if dim_count < 2 then
result := evalBuiltinPromoteWork(arg, 2);
elseif dim_count == 2 then
result := arg;
else
dim1 :: dim2 :: _ := Type.arrayDims(ty);
ty := Type.liftArrayLeft(Type.arrayElementType(ty), dim2);
expl := list(evalBuiltinMatrix2(e, ty) for e in arg.elements);
ty := Type.liftArrayLeft(ty, dim1);
result := Expression.ARRAY(ty, expl);
end if;
then
result;

case _ guard Type.isScalar(Expression.typeOf(arg))
then evalBuiltinPromoteWork(arg, 2);

else algorithm printWrongArgsError(getInstanceName(), {arg}, sourceInfo()); then fail();
end match;
end evalBuiltinMatrix;

function evalBuiltinMatrix2
input Expression arg;
input Type ty;
output Expression result;
algorithm
result := match arg
case Expression.ARRAY()
then Expression.ARRAY(ty, list(Expression.toScalar(e) for e in arg.elements));

else algorithm printWrongArgsError(getInstanceName(), {arg}, sourceInfo()); then fail();
end match;
end evalBuiltinMatrix2;

function evalBuiltinMax
input list<Expression> args;
output Expression result;
Expand Down Expand Up @@ -2030,6 +2077,38 @@ algorithm
end match;
end evalBuiltinSumReal;

function evalBuiltinSymmetric
input Expression arg;
output Expression result;
protected
array<array<Expression>> mat;
Integer n;
Type row_ty;
list<Expression> expl, accum = {};
algorithm
result := match arg
case Expression.ARRAY() guard Type.isMatrix(arg.ty)
algorithm
mat := listArray(list(listArray(Expression.arrayElements(row))
for row in Expression.arrayElements(arg)));
n := arrayLength(mat);
row_ty := Type.unliftArray(arg.ty);

for i in n:-1:1 loop
expl := {};
for j in n:-1:1 loop
expl := (if i > j then arrayGet(mat[j], i) else arrayGet(mat[i], j)) :: expl;
end for;

accum := Expression.ARRAY(row_ty, expl) :: accum;
end for;
then
Expression.ARRAY(arg.ty, accum);

else algorithm printWrongArgsError(getInstanceName(), {arg}, sourceInfo()); then fail();
end match;
end evalBuiltinSymmetric;

function evalBuiltinTanh
input Expression arg;
output Expression result;
Expand Down
4 changes: 2 additions & 2 deletions Compiler/NFFrontEnd/NFComponentRef.mo
Expand Up @@ -697,13 +697,13 @@ public
local
list<Subscript> subs;

case CREF(subscripts = {})
case CREF(subscripts = {}, origin = Origin.CREF)
algorithm
cref.restCref := simplifySubscripts(cref.restCref);
then
cref;

case CREF()
case CREF(origin = Origin.CREF)
algorithm
subs := list(Subscript.simplify(s) for s in cref.subscripts);
then
Expand Down
35 changes: 21 additions & 14 deletions Compiler/NFFrontEnd/NFConvertDAE.mo
Expand Up @@ -805,23 +805,30 @@ function convertIfStatement
input DAE.ElementSource source;
output DAE.Statement ifStatement;
protected
DAE.Exp cond1, cond2;
list<DAE.Statement> stmts1, stmts2;
tuple<Expression, list<Statement>> head;
list<tuple<Expression, list<Statement>>> rest;
DAE.Else elseStatement = DAE.Else.NOELSE();
Expression cond;
DAE.Exp dcond;
list<Statement> stmts;
list<DAE.Statement> dstmts;
Boolean first = true;
DAE.Else else_stmt = DAE.Else.NOELSE();
algorithm
head :: rest := ifBranches;
cond1 := Expression.toDAE(Util.tuple21(head));
stmts1 := convertStatements(Util.tuple22(head));

for b in listReverse(rest) loop
cond2 := Expression.toDAE(Util.tuple21(b));
stmts2 := convertStatements(Util.tuple22(b));
elseStatement := DAE.Else.ELSEIF(cond2, stmts2, elseStatement);
for b in listReverse(ifBranches) loop
(cond, stmts) := b;
dcond := Expression.toDAE(cond);
dstmts := convertStatements(stmts);

if first and Expression.isTrue(cond) then
else_stmt := DAE.Else.ELSE(dstmts);
else
else_stmt := DAE.Else.ELSEIF(dcond, dstmts, else_stmt);
end if;

first := false;
end for;

ifStatement := DAE.Statement.STMT_IF(cond1, stmts1, elseStatement, source);
// This should always be an ELSEIF due to branch selection in earlier phases.
DAE.Else.ELSEIF(dcond, dstmts, else_stmt) := else_stmt;
ifStatement := DAE.Statement.STMT_IF(dcond, dstmts, else_stmt, source);
end convertIfStatement;

function convertWhenStatement
Expand Down
9 changes: 9 additions & 0 deletions Compiler/NFFrontEnd/NFEquation.mo
Expand Up @@ -112,6 +112,15 @@ public
DAE.ElementSource source;
end NORETCALL;

function makeIf
input list<tuple<Expression, list<Equation>>> branches;
input DAE.ElementSource src;
output Equation eq;
algorithm
eq := IF(branches, src);
annotation(__OpenModelica_EarlyInline=true);
end makeIf;

function source
input Equation eq;
output DAE.ElementSource source;
Expand Down
38 changes: 32 additions & 6 deletions Compiler/NFFrontEnd/NFExpression.mo
Expand Up @@ -687,7 +687,7 @@ public
Type exp_ty;
ComponentRef cref;
algorithm
is_scalar_const := isScalarConst(indexExp);
is_scalar_const := isScalarLiteral(indexExp);

// check exp has array type. Don't apply subs to scalar exp.
exp_ty := typeOf(exp);
Expand Down Expand Up @@ -2925,7 +2925,7 @@ public
input Expression exp2;
output Expression exp;
algorithm
if isScalarConst(exp1) and isScalarConst(exp2) then
if isScalarLiteral(exp1) and isScalarLiteral(exp2) then
exp := Ceval.evalBinaryOp(exp1, op, exp2);
else
exp := BINARY(exp1, op, exp2);
Expand Down Expand Up @@ -3116,19 +3116,35 @@ public
end match;
end isZero;

function isScalarConst
function isScalarLiteral
input Expression exp;
output Boolean isScalar;
output Boolean literal;
algorithm
isScalar := match exp
literal := match exp
case INTEGER() then true;
case REAL() then true;
case STRING() then true;
case BOOLEAN() then true;
case ENUM_LITERAL() then true;
else false;
end match;
end isScalarConst;
end isScalarLiteral;

function isLiteral
input Expression exp;
output Boolean literal;
algorithm
literal := match exp
case INTEGER() then true;
case REAL() then true;
case STRING() then true;
case BOOLEAN() then true;
case ENUM_LITERAL() then true;
case ARRAY() then List.all(exp.elements, isLiteral);
case RECORD() then List.all(exp.elements, isLiteral);
else false;
end match;
end isLiteral;

function isInteger
input Expression exp;
Expand Down Expand Up @@ -3539,5 +3555,15 @@ public
end match;
end enumIndexExp;

function toScalar
input Expression exp;
output Expression outExp;
algorithm
outExp := match exp
case ARRAY(elements = {outExp}) then toScalar(outExp);
else exp;
end match;
end toScalar;

annotation(__OpenModelica_Interface="frontend");
end NFExpression;
3 changes: 2 additions & 1 deletion Compiler/NFFrontEnd/NFFunction.mo
Expand Up @@ -1098,7 +1098,8 @@ uniontype Function

function isExternal
input Function fn;
output Boolean isExternal = Class.isExternalFunction(InstNode.getClass(fn.node));
output Boolean isExternal = not InstNode.isEmpty(fn.node) and
Class.isExternalFunction(InstNode.getClass(fn.node));
end isExternal;

function inlineBuiltin
Expand Down

0 comments on commit c097132

Please sign in to comment.