Skip to content
This repository was archived by the owner on May 18, 2019. It is now read-only.

Commit 6da90a0

Browse files
perostOpenModelica-Hudson
authored andcommitted
[NF] Generate fold expression for reductions.
- Fill in the fold expression in the DAE.REDUCTIONINFO record when converting reduction to DAE form. Belonging to [master]: - #2764
1 parent 2726ae4 commit 6da90a0

File tree

2 files changed

+92
-22
lines changed

2 files changed

+92
-22
lines changed

Compiler/NFFrontEnd/NFCall.mo

Lines changed: 79 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ import TypeCheck = NFTypeCheck;
6565
import Typing = NFTyping;
6666
import Util;
6767
import Subscript = NFSubscript;
68+
import Operator = NFOperator;
6869

6970
public
7071
uniontype CallAttributes
@@ -677,37 +678,48 @@ uniontype Call
677678
output DAE.Exp daeCall;
678679
algorithm
679680
daeCall := match call
681+
local
682+
String fold_id, res_id;
683+
680684
case TYPED_CALL()
681685
then DAE.CALL(
682686
Function.nameConsiderBuiltin(call.fn),
683687
list(Expression.toDAE(e) for e in call.arguments),
684688
CallAttributes.toDAE(call.attributes));
685689

686690
case TYPED_ARRAY_CONSTRUCTOR()
687-
then DAE.REDUCTION(
688-
DAE.REDUCTIONINFO(
689-
Function.name(NFBuiltinFuncs.ARRAY_FUNC),
690-
Absyn.COMBINE(),
691-
Type.toDAE(call.ty),
692-
NONE(),
693-
String(Util.getTempVariableIndex()),
694-
String(Util.getTempVariableIndex()),
695-
NONE()),
696-
Expression.toDAE(call.exp),
697-
list(iteratorToDAE(iter) for iter in call.iters));
691+
algorithm
692+
fold_id := Util.getTempVariableIndex();
693+
res_id := Util.getTempVariableIndex();
694+
then
695+
DAE.REDUCTION(
696+
DAE.REDUCTIONINFO(
697+
Function.name(NFBuiltinFuncs.ARRAY_FUNC),
698+
Absyn.COMBINE(),
699+
Type.toDAE(call.ty),
700+
NONE(),
701+
fold_id,
702+
res_id,
703+
NONE()),
704+
Expression.toDAE(call.exp),
705+
list(iteratorToDAE(iter) for iter in call.iters));
698706

699707
case TYPED_REDUCTION()
700-
then DAE.REDUCTION(
701-
DAE.REDUCTIONINFO(
702-
Function.name(call.fn),
703-
Absyn.COMBINE(),
704-
Type.toDAE(call.ty),
705-
SOME(Expression.toDAEValue(reductionDefaultValue(call))),
706-
String(Util.getTempVariableIndex()),
707-
String(Util.getTempVariableIndex()),
708-
NONE()),
709-
Expression.toDAE(call.exp),
710-
list(iteratorToDAE(iter) for iter in call.iters));
708+
algorithm
709+
fold_id := Util.getTempVariableIndex();
710+
res_id := Util.getTempVariableIndex();
711+
then
712+
DAE.REDUCTION(
713+
DAE.REDUCTIONINFO(
714+
Function.name(call.fn),
715+
Absyn.COMBINE(),
716+
Type.toDAE(call.ty),
717+
SOME(Expression.toDAEValue(reductionDefaultValue(call))),
718+
fold_id,
719+
res_id,
720+
Expression.toDAEOpt(reductionFoldExpression(call.fn, call.ty, call.var, fold_id, res_id))),
721+
Expression.toDAE(call.exp),
722+
list(iteratorToDAE(iter) for iter in call.iters));
711723

712724
else
713725
algorithm
@@ -734,6 +746,51 @@ uniontype Call
734746
end match;
735747
end reductionDefaultValue;
736748

749+
function reductionFoldExpression
750+
input Function reductionFn;
751+
input Type reductionType;
752+
input Variability reductionVar;
753+
input String foldId;
754+
input String resultId;
755+
output Option<Expression> foldExp;
756+
protected
757+
Type ty;
758+
algorithm
759+
foldExp := match Absyn.pathFirstIdent(Function.name(reductionFn))
760+
case "sum"
761+
then SOME(Expression.BINARY(
762+
reductionFoldIterator(resultId, reductionType),
763+
Operator.makeAdd(reductionType),
764+
reductionFoldIterator(foldId, reductionType)));
765+
766+
case "product"
767+
then SOME(Expression.BINARY(
768+
reductionFoldIterator(resultId, reductionType),
769+
Operator.makeMul(reductionType),
770+
reductionFoldIterator(foldId, reductionType)));
771+
772+
case "$array" then NONE();
773+
case "array" then NONE();
774+
case "list" then NONE();
775+
case "listReverse" then NONE();
776+
777+
else
778+
SOME(Expression.CALL(Call.makeTypedCall(reductionFn,
779+
{reductionFoldIterator(foldId, reductionType),
780+
reductionFoldIterator(resultId, reductionType)},
781+
reductionVar, reductionType)));
782+
783+
end match;
784+
end reductionFoldExpression;
785+
786+
function reductionFoldIterator
787+
input String name;
788+
input Type ty;
789+
output Expression iterExp;
790+
algorithm
791+
iterExp := Expression.CREF(ty, ComponentRef.makeIterator(InstNode.NAME_NODE(name), ty));
792+
end reductionFoldIterator;
793+
737794
function isVectorizeable
738795
input Call call;
739796
output Boolean isVect;

Compiler/NFFrontEnd/NFExpression.mo

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1431,6 +1431,19 @@ public
14311431
end match;
14321432
end isAssociativeExp;
14331433

1434+
function toDAEOpt
1435+
input Option<Expression> exp;
1436+
output Option<DAE.Exp> dexp;
1437+
algorithm
1438+
dexp := match exp
1439+
local
1440+
Expression e;
1441+
1442+
case SOME(e) then SOME(toDAE(e));
1443+
else NONE();
1444+
end match;
1445+
end toDAEOpt;
1446+
14341447
function toDAE
14351448
input Expression exp;
14361449
output DAE.Exp dexp;

0 commit comments

Comments
 (0)