Skip to content
This repository has been archived by the owner on May 18, 2019. It is now read-only.

Commit

Permalink
Handle tail recursion in CevalFunction
Browse files Browse the repository at this point in the history
This fixes ticket:4838

Belonging to [master]:
  - #2300
  - OpenModelica/OpenModelica-testsuite#889
  • Loading branch information
sjoelund authored and OpenModelica-Hudson committed Mar 21, 2018
1 parent 748b75a commit 05a951b
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 20 deletions.
29 changes: 25 additions & 4 deletions Compiler/FrontEnd/CevalFunction.mo
Expand Up @@ -1031,12 +1031,16 @@ algorithm
FCore.Graph env;
DAE.Exp lhs, rhs, condition;
DAE.ComponentRef lhs_cref;
Values.Value rhs_val;
Values.Value rhs_val, v;
list<DAE.Exp> exps;
list<Values.Value> vals;
list<DAE.Statement> statements;
Absyn.Path path;
DAE.Type returnType;
LoopControl loop_ctrl;
DAE.TailCall tailCall;
String var;
list<String> vars;

case (DAE.STMT_ASSIGN(exp1 = lhs, exp = rhs), cache, env)
equation
Expand Down Expand Up @@ -1100,12 +1104,29 @@ algorithm
then
(cache, inEnv, NEXT());
// Special case for print, and other known calls for now; evaluated even when there is no ST
case (DAE.STMT_NORETCALL(exp = rhs as DAE.CALL(path = path, expLst = exps)), _, _)
case (DAE.STMT_NORETCALL(exp = rhs as DAE.CALL(path = path, expLst = exps, attr=DAE.CALL_ATTR(ty=returnType, tailCall=tailCall))), _, _)
algorithm
(cache, vals) := cevalExpList(exps, inCache, inEnv);
(cache, _) := cevalExp(rhs, cache, inEnv);
(cache, v) := cevalExp(rhs, cache, inEnv);
(cache, env, outLoopControl) := match tailCall
case DAE.NO_TAIL() then (cache, inEnv, NEXT());
// Handle tail recursion; same as a assigning all outputs followed by a return
case DAE.TAIL(outVars={}) then (cache, inEnv, RETURN());
case DAE.TAIL(outVars={var})
algorithm
(cache, env) := assignVariable(ComponentReference.makeUntypedCrefIdent(var), v, cache, inEnv);
then (cache, env, RETURN());
case DAE.TAIL(outVars=vars)
algorithm
Values.TUPLE(vals) := v;
for val in vals loop
var::vars := vars;
(cache, env) := assignVariable(ComponentReference.makeUntypedCrefIdent(var), val, cache, inEnv);
end for;
then (cache, env, RETURN());
end match;
then
(cache, inEnv, NEXT());
(cache, env, NEXT());

case (DAE.STMT_RETURN(), _, _)
then
Expand Down
1 change: 1 addition & 0 deletions Compiler/FrontEnd/DAE.mo
Expand Up @@ -1534,6 +1534,7 @@ public uniontype TailCall
end NO_TAIL;
record TAIL
list<String> vars;
list<String> outVars;
end TAIL;
end TailCall;

Expand Down
17 changes: 14 additions & 3 deletions Compiler/FrontEnd/DAEDump.mo
Expand Up @@ -1462,8 +1462,15 @@ algorithm
();

case (DAE.STMT_NORETCALL(exp = e1),i)
equation
algorithm
indent(i);
_ := match e1
case DAE.CALL(attr=DAE.CALL_ATTR(tailCall=DAE.TAIL()))
algorithm
Print.printBuf("return ");
then ();
else ();
end match;
ExpressionDump.printExp(e1);
Print.printBuf(";\n");
then
Expand Down Expand Up @@ -1716,8 +1723,12 @@ algorithm
case (DAE.STMT_NORETCALL(exp = e),i)
equation
s1 = indentStr(i);
s2 = ExpressionDump.printExpStr(e);
str = stringAppendList({s1,s2,";\n"});
s2 = match e
case DAE.CALL(attr=DAE.CALL_ATTR(tailCall=DAE.TAIL())) then "return ";
else "";
end match;
s3 = ExpressionDump.printExpStr(e);
str = stringAppendList({s1,s2,s3,";\n"});
then
str;

Expand Down
24 changes: 12 additions & 12 deletions Compiler/FrontEnd/InstUtil.mo
Expand Up @@ -7551,13 +7551,13 @@ protected function optimizeStatementTail2
output DAE.Exp orhs;
algorithm
true:=valueEq(lhsVars,outvars);
(orhs,true) := optimizeStatementTail3(path,rhs,invars,source);
(orhs,true) := optimizeStatementTail3(path,rhs,invars,lhsVars,source);
end optimizeStatementTail2;

protected function optimizeStatementTail3
input Absyn.Path path;
input DAE.Exp rhs;
input list<String> vars;
input list<String> vars, lhsVars;
input DAE.ElementSource source;
output DAE.Exp orhs;
output Boolean isTailRecursive;
Expand Down Expand Up @@ -7585,18 +7585,18 @@ algorithm
if Flags.isSet(Flags.TAIL) then
Error.addSourceMessage(Error.COMPILER_NOTIFICATION,{str},ElementSource.getElementSourceFileInfo(source));
end if;
attr.tailCall = DAE.TAIL(vars);
attr.tailCall = DAE.TAIL(vars,lhsVars);
call.attr = attr;
then (call,true);
case (_,DAE.IFEXP(e1,e2,e3),_,_)
equation
(e2,b1) = optimizeStatementTail3(path,e2,vars,source);
(e3,b2) = optimizeStatementTail3(path,e3,vars,source);
(e2,b1) = optimizeStatementTail3(path,e2,vars,lhsVars,source);
(e3,b2) = optimizeStatementTail3(path,e3,vars,lhsVars,source);
true = b1 or b2;
then (DAE.IFEXP(e1,e2,e3),true);
case (_,DAE.MATCHEXPRESSION(matchType as DAE.MATCH(_) /*TODO:matchcontinue*/,inputs,aliases,localDecls,cases,et),_,_)
equation
cases = optimizeStatementTailMatchCases(path,cases,false,{},vars,source);
cases = optimizeStatementTailMatchCases(path,cases,false,{},vars,lhsVars,source);
then (DAE.MATCHEXPRESSION(matchType,inputs,aliases,localDecls,cases,et),true);
else (rhs,false);
end matchcontinue;
Expand All @@ -7607,7 +7607,7 @@ protected function optimizeStatementTailMatchCases
input list<DAE.MatchCase> inCases;
input Boolean changed;
input list<DAE.MatchCase> inAcc;
input list<String> vars;
input list<String> vars, lhsVars;
input DAE.ElementSource source;
output list<DAE.MatchCase> ocases;
algorithm
Expand All @@ -7627,18 +7627,18 @@ algorithm
case (_,{},true,acc,_,_) then listReverse(acc);
case (_,DAE.CASE(patterns,patternGuard,localDecls,body,SOME(exp),resultInfo,jump,info)::cases,_,acc,_,_)
equation
(exp,true) = optimizeStatementTail3(path,exp,vars,source);
(exp,true) = optimizeStatementTail3(path,exp,vars,lhsVars,source);
case_ = DAE.CASE(patterns,patternGuard,localDecls,body,SOME(exp),resultInfo,jump,info);
then optimizeStatementTailMatchCases(path,cases,true,case_::acc,vars,source);
then optimizeStatementTailMatchCases(path,cases,true,case_::acc,vars,lhsVars,source);
case (_,DAE.CASE(patterns,patternGuard,localDecls,body,SOME(DAE.TUPLE({})),resultInfo,jump,info)::cases,_,acc,_,_)
equation
DAE.STMT_NORETCALL(exp,sourceStmt) = List.last(body);
(exp,true) = optimizeStatementTail3(path,exp,vars,source);
(exp,true) = optimizeStatementTail3(path,exp,vars,lhsVars,source);
body = List.set(body,listLength(body),DAE.STMT_NORETCALL(exp,sourceStmt));
case_ = DAE.CASE(patterns,patternGuard,localDecls,body,SOME(DAE.TUPLE({})),resultInfo,jump,info);
then optimizeStatementTailMatchCases(path,cases,true,case_::acc,vars,source);
then optimizeStatementTailMatchCases(path,cases,true,case_::acc,vars,lhsVars,source);
case (_,case_::cases,_,acc,_,_)
then optimizeStatementTailMatchCases(path,cases,changed,case_::acc,vars,source);
then optimizeStatementTailMatchCases(path,cases,changed,case_::acc,vars,lhsVars,source);
end matchcontinue;
end optimizeStatementTailMatchCases;

Expand Down
14 changes: 14 additions & 0 deletions Compiler/Template/DAEDumpTV.mo
Expand Up @@ -1391,6 +1391,20 @@ package DAE
end SUM;

end Exp;

uniontype CallAttributes
record CALL_ATTR
TailCall tailCall "Input variables of the function if the call is tail-recursive";
end CALL_ATTR;
end CallAttributes;

uniontype TailCall
record NO_TAIL
end NO_TAIL;
record TAIL
end TAIL;
end TailCall;

end DAE;

package SCodeDump
Expand Down
5 changes: 4 additions & 1 deletion Compiler/Template/DAEDumpTpl.tpl
Expand Up @@ -698,8 +698,11 @@ template dumpNoRetCall(DAE.Exp call_exp, DAE.ElementSource src)
::=
let call_str = dumpExp(call_exp)
let src_str = dumpSource(src)
let tail_str = match call_exp
case CALL(attr=CALL_ATTR(tailCall=TAIL(__))) then "return "
else ""
<<
<%call_str%><%src_str%>;
<%tail_str%><%call_str%><%src_str%>;
>>
end dumpNoRetCall;

Expand Down

0 comments on commit 05a951b

Please sign in to comment.