Skip to content
This repository was archived by the owner on May 18, 2019. It is now read-only.

Commit 5feb832

Browse files
perostOpenModelica-Hudson
authored andcommitted
[NF] Implemented basic function inlining.
- Implemented early inlining of functions with one output and one statement. - Removed Ceval.evalBuiltinCross, the definition in ModelicaBuiltin is used instead. - Updated Expression map and fold function to also consider reduction iterators. Belonging to [master]: - #2412 - OpenModelica/OpenModelica-testsuite#938
1 parent 87286cb commit 5feb832

File tree

8 files changed

+284
-58
lines changed

8 files changed

+284
-58
lines changed

Compiler/NFFrontEnd/NFCall.mo

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ import NFFunction.MatchedFunction;
6767
import Ceval = NFCeval;
6868
import SimplifyExp = NFSimplifyExp;
6969
import Subscript = NFSubscript;
70+
import Inline = NFInline;
7071

7172
public
7273
uniontype CallAttributes
@@ -453,6 +454,7 @@ uniontype Call
453454
outExp := toRecordExpression(call, ty);
454455
else
455456
outExp := Expression.CALL(call);
457+
outExp := Inline.inlineCallExp(outExp);
456458
end if;
457459
end if;
458460
then
@@ -2676,6 +2678,17 @@ protected
26762678
then Expression.RECORD(Absyn.stripLast(Function.name(call.fn)), ty, call.arguments);
26772679
end match;
26782680
end toRecordExpression;
2681+
2682+
function inlineType
2683+
input Call call;
2684+
output DAE.InlineType inlineTy;
2685+
algorithm
2686+
inlineTy := match call
2687+
case TYPED_CALL(attributes = CallAttributes.CALL_ATTR(inlineType = inlineTy))
2688+
then inlineTy;
2689+
else DAE.InlineType.NO_INLINE();
2690+
end match;
2691+
end inlineType;
26792692
end Call;
26802693

26812694
annotation(__OpenModelica_Interface="frontend");

Compiler/NFFrontEnd/NFCeval.mo

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1139,7 +1139,7 @@ algorithm
11391139
else
11401140
evalNormalCall(call.fn, args, call);
11411141

1142-
case Call.UNTYPED_MAP_CALL()
1142+
case Call.TYPED_MAP_CALL()
11431143
algorithm
11441144
Error.addInternalError(getInstanceName() + ": unimplemented case for mapcall", sourceInfo());
11451145
then
@@ -1173,7 +1173,6 @@ algorithm
11731173
case "ceil" then evalBuiltinCeil(listHead(args));
11741174
case "cosh" then evalBuiltinCosh(listHead(args));
11751175
case "cos" then evalBuiltinCos(listHead(args));
1176-
case "cross" then evalBuiltinCross(args);
11771176
case "der" then evalBuiltinDer(listHead(args));
11781177
// TODO: Fix typing of diagonal so the argument isn't boxed.
11791178
case "diagonal" then evalBuiltinDiagonal(Expression.unbox(listHead(args)));
@@ -1418,27 +1417,6 @@ algorithm
14181417
end match;
14191418
end evalBuiltinCos;
14201419

1421-
function evalBuiltinCross
1422-
input list<Expression> args;
1423-
output Expression result;
1424-
protected
1425-
Real x1, x2, x3, y1, y2, y3;
1426-
Expression z1, z2, z3;
1427-
algorithm
1428-
result := match args
1429-
case {Expression.ARRAY(elements = {Expression.REAL(x1), Expression.REAL(x2), Expression.REAL(x3)}),
1430-
Expression.ARRAY(elements = {Expression.REAL(y1), Expression.REAL(y2), Expression.REAL(y3)})}
1431-
algorithm
1432-
z1 := Expression.REAL(x2 * y3 - x3 * y2);
1433-
z2 := Expression.REAL(x3 * y1 - x1 * y3);
1434-
z3 := Expression.REAL(x1 * y2 - x2 * y1);
1435-
then
1436-
Expression.ARRAY(Type.ARRAY(Type.REAL(), {Dimension.fromInteger(3)}), {z1, z2, z3});
1437-
1438-
else algorithm printWrongArgsError(getInstanceName(), args, sourceInfo()); then fail();
1439-
end match;
1440-
end evalBuiltinCross;
1441-
14421420
function evalBuiltinDer
14431421
input Expression arg;
14441422
output Expression result;

Compiler/NFFrontEnd/NFEvalFunction.mo

Lines changed: 1 addition & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ algorithm
119119
Pointer.update(call_counter, call_count);
120120

121121
try
122-
fn_body := getFunctionBody(fn.node);
122+
fn_body := Function.getBody(fn);
123123
repl := createReplacements(fn, args);
124124
// TODO: Also apply replacements to the replacements themselves, i.e. the
125125
// bindings of the function parameters. But the probably need to be
@@ -165,32 +165,6 @@ end evaluateExternal;
165165

166166
protected
167167

168-
function getFunctionBody
169-
input InstNode node;
170-
output list<Statement> body;
171-
protected
172-
Class cls = InstNode.getClass(node);
173-
algorithm
174-
body := match cls
175-
case Class.INSTANCED_CLASS(sections = Sections.SECTIONS(algorithms = {body})) then body;
176-
177-
case Class.INSTANCED_CLASS(sections = Sections.SECTIONS(algorithms = _ :: _))
178-
algorithm
179-
Error.assertion(false, getInstanceName() + " got function with multiple algorithm sections", sourceInfo());
180-
then
181-
fail();
182-
183-
case Class.TYPED_DERIVED() then getFunctionBody(cls.baseClass);
184-
185-
else
186-
algorithm
187-
Error.assertion(false, getInstanceName() + " got unknown function", sourceInfo());
188-
then
189-
fail();
190-
191-
end match;
192-
end getFunctionBody;
193-
194168
function createReplacements
195169
input Function fn;
196170
input list<Expression> args;

Compiler/NFFrontEnd/NFExpression.mo

Lines changed: 78 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -770,8 +770,6 @@ public
770770
input output Expression exp;
771771
input InstNode iterator;
772772
input Expression iteratorValue;
773-
774-
import Origin = NFComponentRef.Origin;
775773
algorithm
776774
exp := match exp
777775
local
@@ -1315,6 +1313,7 @@ public
13151313
Expression e;
13161314
Type t;
13171315
Variability v;
1316+
list<tuple<InstNode, Expression>> iters;
13181317

13191318
case Call.UNTYPED_CALL()
13201319
algorithm
@@ -1357,18 +1356,41 @@ public
13571356
case Call.UNTYPED_MAP_CALL()
13581357
algorithm
13591358
e := map(call.exp, func);
1359+
iters := mapCallIterators(call.iters, func);
13601360
then
1361-
Call.UNTYPED_MAP_CALL(e, call.iters);
1361+
Call.UNTYPED_MAP_CALL(e, iters);
13621362

13631363
case Call.TYPED_MAP_CALL()
13641364
algorithm
13651365
e := map(call.exp, func);
1366+
iters := mapCallIterators(call.iters, func);
13661367
then
1367-
Call.TYPED_MAP_CALL(call.ty, call.var, e, call.iters);
1368+
Call.TYPED_MAP_CALL(call.ty, call.var, e, iters);
13681369

13691370
end match;
13701371
end mapCall;
13711372

1373+
function mapCallIterators
1374+
input list<tuple<InstNode, Expression>> iters;
1375+
input MapFunc func;
1376+
output list<tuple<InstNode, Expression>> outIters = {};
1377+
1378+
partial function MapFunc
1379+
input output Expression e;
1380+
end MapFunc;
1381+
protected
1382+
InstNode node;
1383+
Expression exp, new_exp;
1384+
algorithm
1385+
for i in iters loop
1386+
(node, exp) := i;
1387+
new_exp := map(exp, func);
1388+
outIters := (if referenceEq(new_exp, exp) then i else (node, new_exp)) :: outIters;
1389+
end for;
1390+
1391+
outIters := listReverseInPlace(outIters);
1392+
end mapCallIterators;
1393+
13721394
function mapCref
13731395
input ComponentRef cref;
13741396
input MapFunc func;
@@ -1656,6 +1678,27 @@ public
16561678
end match;
16571679
end mapCallShallow;
16581680

1681+
function mapCallShallowIterators
1682+
input list<tuple<InstNode, Expression>> iters;
1683+
input MapFunc func;
1684+
output list<tuple<InstNode, Expression>> outIters = {};
1685+
1686+
partial function MapFunc
1687+
input output Expression e;
1688+
end MapFunc;
1689+
protected
1690+
InstNode node;
1691+
Expression exp, new_exp;
1692+
algorithm
1693+
for i in iters loop
1694+
(node, exp) := i;
1695+
new_exp := func(exp);
1696+
outIters := (if referenceEq(new_exp, exp) then i else (node, new_exp)) :: outIters;
1697+
end for;
1698+
1699+
outIters := listReverseInPlace(outIters);
1700+
end mapCallShallowIterators;
1701+
16591702
function mapArrayElements
16601703
"Applies the given function to each scalar elements of an array."
16611704
input Expression exp;
@@ -1839,12 +1882,20 @@ public
18391882
case Call.UNTYPED_MAP_CALL()
18401883
algorithm
18411884
foldArg := fold(call.exp, func, foldArg);
1885+
1886+
for i in call.iters loop
1887+
foldArg := fold(Util.tuple22(i), func, foldArg);
1888+
end for;
18421889
then
18431890
();
18441891

18451892
case Call.TYPED_MAP_CALL()
18461893
algorithm
18471894
foldArg := fold(call.exp, func, foldArg);
1895+
1896+
for i in call.iters loop
1897+
foldArg := fold(Util.tuple22(i), func, foldArg);
1898+
end for;
18481899
then
18491900
();
18501901

@@ -2141,6 +2192,29 @@ public
21412192
end match;
21422193
end mapFoldCall;
21432194

2195+
function mapFoldCallIterators<ArgT>
2196+
input list<tuple<InstNode, Expression>> iters;
2197+
input MapFunc func;
2198+
output list<tuple<InstNode, Expression>> outIters = {};
2199+
input output ArgT arg;
2200+
2201+
partial function MapFunc
2202+
input output Expression e;
2203+
input output ArgT arg;
2204+
end MapFunc;
2205+
protected
2206+
InstNode node;
2207+
Expression exp, new_exp;
2208+
algorithm
2209+
for i in iters loop
2210+
(node, exp) := i;
2211+
(new_exp, arg) := mapFold(exp, func, arg);
2212+
outIters := (if referenceEq(new_exp, exp) then i else (node, new_exp)) :: outIters;
2213+
end for;
2214+
2215+
outIters := listReverseInPlace(outIters);
2216+
end mapFoldCallIterators;
2217+
21442218
function mapFoldCref<ArgT>
21452219
input ComponentRef cref;
21462220
input MapFunc func;

Compiler/NFFrontEnd/NFFunction.mo

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ import MatchKind = NFTypeCheck.MatchKind;
6666
import Restriction = NFRestriction;
6767
import NFTyping.ExpOrigin;
6868
import Dimension = NFDimension;
69+
import Statement = NFStatement;
70+
import Sections = NFSections;
6971

7072

7173
public
@@ -1155,6 +1157,11 @@ uniontype Function
11551157
ty := DAE.T_FUNCTION(params, Type.toDAE(fn.returnType), fn.attributes, fn.path);
11561158
end makeDAEType;
11571159

1160+
function getBody
1161+
input Function fn;
1162+
output list<Statement> body = getBody2(fn.node);
1163+
end getBody;
1164+
11581165
protected
11591166
function collectParams
11601167
"Sorts all the function parameters as inputs, outputs and locals."
@@ -1453,6 +1460,33 @@ protected
14531460
else Type.TUPLE(ret_tyl, NONE());
14541461
end match;
14551462
end makeReturnType;
1463+
1464+
function getBody2
1465+
input InstNode node;
1466+
output list<Statement> body;
1467+
protected
1468+
Class cls = InstNode.getClass(node);
1469+
algorithm
1470+
body := match cls
1471+
case Class.INSTANCED_CLASS(sections = Sections.SECTIONS(algorithms = {body})) then body;
1472+
case Class.INSTANCED_CLASS(sections = Sections.EMPTY()) then {};
1473+
1474+
case Class.INSTANCED_CLASS(sections = Sections.SECTIONS(algorithms = _ :: _))
1475+
algorithm
1476+
Error.assertion(false, getInstanceName() + " got function with multiple algorithm sections", sourceInfo());
1477+
then
1478+
fail();
1479+
1480+
case Class.TYPED_DERIVED() then getBody2(cls.baseClass);
1481+
1482+
else
1483+
algorithm
1484+
Error.assertion(false, getInstanceName() + " got unknown function", sourceInfo());
1485+
then
1486+
fail();
1487+
1488+
end match;
1489+
end getBody2;
14561490
end Function;
14571491

14581492
annotation(__OpenModelica_Interface="frontend");

0 commit comments

Comments
 (0)