Skip to content

Commit

Permalink
[NF] Implemented basic function inlining.
Browse files Browse the repository at this point in the history
- Implemented early inlining of functions with one output and one
  statement.
- Removed Ceval.evalBuiltinCross, the definition in ModelicaBuiltin
  is used instead.
- Updated Expression map and fold function to also consider reduction
  iterators.

Belonging to [master]:
  - OpenModelica/OMCompiler#2412
  - OpenModelica/OpenModelica-testsuite#938
  • Loading branch information
perost authored and OpenModelica-Hudson committed May 7, 2018
1 parent 87286cb commit 5feb832
Show file tree
Hide file tree
Showing 8 changed files with 284 additions and 58 deletions.
13 changes: 13 additions & 0 deletions Compiler/NFFrontEnd/NFCall.mo
Expand Up @@ -67,6 +67,7 @@ import NFFunction.MatchedFunction;
import Ceval = NFCeval;
import SimplifyExp = NFSimplifyExp;
import Subscript = NFSubscript;
import Inline = NFInline;

public
uniontype CallAttributes
Expand Down Expand Up @@ -453,6 +454,7 @@ uniontype Call
outExp := toRecordExpression(call, ty);
else
outExp := Expression.CALL(call);
outExp := Inline.inlineCallExp(outExp);
end if;
end if;
then
Expand Down Expand Up @@ -2676,6 +2678,17 @@ protected
then Expression.RECORD(Absyn.stripLast(Function.name(call.fn)), ty, call.arguments);
end match;
end toRecordExpression;

function inlineType
input Call call;
output DAE.InlineType inlineTy;
algorithm
inlineTy := match call
case TYPED_CALL(attributes = CallAttributes.CALL_ATTR(inlineType = inlineTy))
then inlineTy;
else DAE.InlineType.NO_INLINE();
end match;
end inlineType;
end Call;

annotation(__OpenModelica_Interface="frontend");
Expand Down
24 changes: 1 addition & 23 deletions Compiler/NFFrontEnd/NFCeval.mo
Expand Up @@ -1139,7 +1139,7 @@ algorithm
else
evalNormalCall(call.fn, args, call);

case Call.UNTYPED_MAP_CALL()
case Call.TYPED_MAP_CALL()
algorithm
Error.addInternalError(getInstanceName() + ": unimplemented case for mapcall", sourceInfo());
then
Expand Down Expand Up @@ -1173,7 +1173,6 @@ algorithm
case "ceil" then evalBuiltinCeil(listHead(args));
case "cosh" then evalBuiltinCosh(listHead(args));
case "cos" then evalBuiltinCos(listHead(args));
case "cross" then evalBuiltinCross(args);
case "der" then evalBuiltinDer(listHead(args));
// TODO: Fix typing of diagonal so the argument isn't boxed.
case "diagonal" then evalBuiltinDiagonal(Expression.unbox(listHead(args)));
Expand Down Expand Up @@ -1418,27 +1417,6 @@ algorithm
end match;
end evalBuiltinCos;

function evalBuiltinCross
input list<Expression> args;
output Expression result;
protected
Real x1, x2, x3, y1, y2, y3;
Expression z1, z2, z3;
algorithm
result := match args
case {Expression.ARRAY(elements = {Expression.REAL(x1), Expression.REAL(x2), Expression.REAL(x3)}),
Expression.ARRAY(elements = {Expression.REAL(y1), Expression.REAL(y2), Expression.REAL(y3)})}
algorithm
z1 := Expression.REAL(x2 * y3 - x3 * y2);
z2 := Expression.REAL(x3 * y1 - x1 * y3);
z3 := Expression.REAL(x1 * y2 - x2 * y1);
then
Expression.ARRAY(Type.ARRAY(Type.REAL(), {Dimension.fromInteger(3)}), {z1, z2, z3});

else algorithm printWrongArgsError(getInstanceName(), args, sourceInfo()); then fail();
end match;
end evalBuiltinCross;

function evalBuiltinDer
input Expression arg;
output Expression result;
Expand Down
28 changes: 1 addition & 27 deletions Compiler/NFFrontEnd/NFEvalFunction.mo
Expand Up @@ -119,7 +119,7 @@ algorithm
Pointer.update(call_counter, call_count);

try
fn_body := getFunctionBody(fn.node);
fn_body := Function.getBody(fn);
repl := createReplacements(fn, args);
// TODO: Also apply replacements to the replacements themselves, i.e. the
// bindings of the function parameters. But the probably need to be
Expand Down Expand Up @@ -165,32 +165,6 @@ end evaluateExternal;

protected

function getFunctionBody
input InstNode node;
output list<Statement> body;
protected
Class cls = InstNode.getClass(node);
algorithm
body := match cls
case Class.INSTANCED_CLASS(sections = Sections.SECTIONS(algorithms = {body})) then body;

case Class.INSTANCED_CLASS(sections = Sections.SECTIONS(algorithms = _ :: _))
algorithm
Error.assertion(false, getInstanceName() + " got function with multiple algorithm sections", sourceInfo());
then
fail();

case Class.TYPED_DERIVED() then getFunctionBody(cls.baseClass);

else
algorithm
Error.assertion(false, getInstanceName() + " got unknown function", sourceInfo());
then
fail();

end match;
end getFunctionBody;

function createReplacements
input Function fn;
input list<Expression> args;
Expand Down
82 changes: 78 additions & 4 deletions Compiler/NFFrontEnd/NFExpression.mo
Expand Up @@ -770,8 +770,6 @@ public
input output Expression exp;
input InstNode iterator;
input Expression iteratorValue;

import Origin = NFComponentRef.Origin;
algorithm
exp := match exp
local
Expand Down Expand Up @@ -1315,6 +1313,7 @@ public
Expression e;
Type t;
Variability v;
list<tuple<InstNode, Expression>> iters;

case Call.UNTYPED_CALL()
algorithm
Expand Down Expand Up @@ -1357,18 +1356,41 @@ public
case Call.UNTYPED_MAP_CALL()
algorithm
e := map(call.exp, func);
iters := mapCallIterators(call.iters, func);
then
Call.UNTYPED_MAP_CALL(e, call.iters);
Call.UNTYPED_MAP_CALL(e, iters);

case Call.TYPED_MAP_CALL()
algorithm
e := map(call.exp, func);
iters := mapCallIterators(call.iters, func);
then
Call.TYPED_MAP_CALL(call.ty, call.var, e, call.iters);
Call.TYPED_MAP_CALL(call.ty, call.var, e, iters);

end match;
end mapCall;

function mapCallIterators
input list<tuple<InstNode, Expression>> iters;
input MapFunc func;
output list<tuple<InstNode, Expression>> outIters = {};

partial function MapFunc
input output Expression e;
end MapFunc;
protected
InstNode node;
Expression exp, new_exp;
algorithm
for i in iters loop
(node, exp) := i;
new_exp := map(exp, func);
outIters := (if referenceEq(new_exp, exp) then i else (node, new_exp)) :: outIters;
end for;

outIters := listReverseInPlace(outIters);
end mapCallIterators;

function mapCref
input ComponentRef cref;
input MapFunc func;
Expand Down Expand Up @@ -1656,6 +1678,27 @@ public
end match;
end mapCallShallow;

function mapCallShallowIterators
input list<tuple<InstNode, Expression>> iters;
input MapFunc func;
output list<tuple<InstNode, Expression>> outIters = {};

partial function MapFunc
input output Expression e;
end MapFunc;
protected
InstNode node;
Expression exp, new_exp;
algorithm
for i in iters loop
(node, exp) := i;
new_exp := func(exp);
outIters := (if referenceEq(new_exp, exp) then i else (node, new_exp)) :: outIters;
end for;

outIters := listReverseInPlace(outIters);
end mapCallShallowIterators;

function mapArrayElements
"Applies the given function to each scalar elements of an array."
input Expression exp;
Expand Down Expand Up @@ -1839,12 +1882,20 @@ public
case Call.UNTYPED_MAP_CALL()
algorithm
foldArg := fold(call.exp, func, foldArg);

for i in call.iters loop
foldArg := fold(Util.tuple22(i), func, foldArg);
end for;
then
();

case Call.TYPED_MAP_CALL()
algorithm
foldArg := fold(call.exp, func, foldArg);

for i in call.iters loop
foldArg := fold(Util.tuple22(i), func, foldArg);
end for;
then
();

Expand Down Expand Up @@ -2141,6 +2192,29 @@ public
end match;
end mapFoldCall;

function mapFoldCallIterators<ArgT>
input list<tuple<InstNode, Expression>> iters;
input MapFunc func;
output list<tuple<InstNode, Expression>> outIters = {};
input output ArgT arg;

partial function MapFunc
input output Expression e;
input output ArgT arg;
end MapFunc;
protected
InstNode node;
Expression exp, new_exp;
algorithm
for i in iters loop
(node, exp) := i;
(new_exp, arg) := mapFold(exp, func, arg);
outIters := (if referenceEq(new_exp, exp) then i else (node, new_exp)) :: outIters;
end for;

outIters := listReverseInPlace(outIters);
end mapFoldCallIterators;

function mapFoldCref<ArgT>
input ComponentRef cref;
input MapFunc func;
Expand Down
34 changes: 34 additions & 0 deletions Compiler/NFFrontEnd/NFFunction.mo
Expand Up @@ -66,6 +66,8 @@ import MatchKind = NFTypeCheck.MatchKind;
import Restriction = NFRestriction;
import NFTyping.ExpOrigin;
import Dimension = NFDimension;
import Statement = NFStatement;
import Sections = NFSections;


public
Expand Down Expand Up @@ -1155,6 +1157,11 @@ uniontype Function
ty := DAE.T_FUNCTION(params, Type.toDAE(fn.returnType), fn.attributes, fn.path);
end makeDAEType;

function getBody
input Function fn;
output list<Statement> body = getBody2(fn.node);
end getBody;

protected
function collectParams
"Sorts all the function parameters as inputs, outputs and locals."
Expand Down Expand Up @@ -1453,6 +1460,33 @@ protected
else Type.TUPLE(ret_tyl, NONE());
end match;
end makeReturnType;

function getBody2
input InstNode node;
output list<Statement> body;
protected
Class cls = InstNode.getClass(node);
algorithm
body := match cls
case Class.INSTANCED_CLASS(sections = Sections.SECTIONS(algorithms = {body})) then body;
case Class.INSTANCED_CLASS(sections = Sections.EMPTY()) then {};

case Class.INSTANCED_CLASS(sections = Sections.SECTIONS(algorithms = _ :: _))
algorithm
Error.assertion(false, getInstanceName() + " got function with multiple algorithm sections", sourceInfo());
then
fail();

case Class.TYPED_DERIVED() then getBody2(cls.baseClass);

else
algorithm
Error.assertion(false, getInstanceName() + " got unknown function", sourceInfo());
then
fail();

end match;
end getBody2;
end Function;

annotation(__OpenModelica_Interface="frontend");
Expand Down

0 comments on commit 5feb832

Please sign in to comment.