Skip to content

Commit 1dfdf69

Browse files
authored
Propagate global context in EvalFunction (#14176)
- Propagate the global context in EvalFunction so that evaluation of expressions inside of functions follow the normal context rules (such as skipping external functions in the instance API).
1 parent 086bf0e commit 1dfdf69

File tree

1 file changed

+40
-21
lines changed

1 file changed

+40
-21
lines changed

OMCompiler/Compiler/NFFrontEnd/NFEvalFunction.mo

Lines changed: 40 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -91,13 +91,14 @@ algorithm
9191
// make sure we don't try to evaluate the non-differentiated function body.
9292
fail();
9393
else
94-
result := evaluateNormal(fn, args);
94+
result := evaluateNormal(fn, args, target.context);
9595
end if;
9696
end evaluate;
9797

9898
function evaluateNormal
9999
input Function fn;
100100
input list<Expression> args;
101+
input InstContext.Type context;
101102
output Expression result;
102103
protected
103104
list<Statement> fn_body;
@@ -106,6 +107,7 @@ protected
106107
Integer call_count, limit;
107108
Pointer<Integer> call_counter = fn.callCounter;
108109
FlowControl ctrl;
110+
InstContext.Type body_context;
109111
algorithm
110112
// Functions contain a mutable call counter that's increased by one at the
111113
// start of each evaluation, and decreased by one when the evalution is
@@ -122,6 +124,8 @@ algorithm
122124

123125
Pointer.update(call_counter, call_count);
124126

127+
body_context := InstContext.clearScopeFlags(context);
128+
125129
try
126130
fn_body := Function.getBody(fn);
127131
arg_map := createArgumentMap(fn.inputs, fn.outputs, fn.locals, args, mutableParams = true);
@@ -130,7 +134,7 @@ algorithm
130134
// sorted by dependencies first.
131135
fn_body := applyReplacements(arg_map, fn_body);
132136
fn_body := optimizeBody(fn_body);
133-
ctrl := evaluateStatements(fn_body);
137+
ctrl := evaluateStatements(fn_body, body_context);
134138

135139
if ctrl <> FlowControl.ASSERTION then
136140
result := createResult(arg_map, fn.outputs);
@@ -712,10 +716,11 @@ end assertAssignedOutput;
712716

713717
function evaluateStatements
714718
input list<Statement> stmts;
719+
input InstContext.Type context;
715720
output FlowControl ctrl = FlowControl.NEXT;
716721
algorithm
717722
for s in stmts loop
718-
ctrl := evaluateStatement(s);
723+
ctrl := evaluateStatement(s, context);
719724

720725
if ctrl <> FlowControl.NEXT then
721726
if ctrl == FlowControl.CONTINUE then
@@ -729,17 +734,18 @@ end evaluateStatements;
729734

730735
function evaluateStatement
731736
input Statement stmt;
737+
input InstContext.Type context;
732738
output FlowControl ctrl;
733739
algorithm
734740
// adrpo: we really need some error handling here to detect which statement cannot be evaluated
735741
// try
736742
ctrl := match stmt
737-
case Statement.ASSIGNMENT() then evaluateAssignment(stmt.lhs, stmt.rhs, stmt.source);
738-
case Statement.FOR() then evaluateFor(stmt.iterator, stmt.range, stmt.body, stmt.source);
739-
case Statement.IF() then evaluateIf(stmt.branches, stmt.source);
740-
case Statement.ASSERT() then evaluateAssert(stmt.condition, stmt);
741-
case Statement.NORETCALL() then evaluateNoRetCall(stmt.exp, stmt.source);
742-
case Statement.WHILE() then evaluateWhile(stmt.condition, stmt.body, stmt.source);
743+
case Statement.ASSIGNMENT() then evaluateAssignment(stmt.lhs, stmt.rhs, stmt.source, context);
744+
case Statement.FOR() then evaluateFor(stmt.iterator, stmt.range, stmt.body, stmt.source, context);
745+
case Statement.IF() then evaluateIf(stmt.branches, stmt.source, context);
746+
case Statement.ASSERT() then evaluateAssert(stmt.condition, stmt, stmt.source, context);
747+
case Statement.NORETCALL() then evaluateNoRetCall(stmt.exp, stmt.source, context);
748+
case Statement.WHILE() then evaluateWhile(stmt.condition, stmt.body, stmt.source, context);
743749
case Statement.RETURN() then FlowControl.RETURN;
744750
case Statement.BREAK() then FlowControl.BREAK;
745751
else
@@ -759,10 +765,11 @@ function evaluateAssignment
759765
input Expression lhsExp;
760766
input Expression rhsExp;
761767
input DAE.ElementSource source;
768+
input InstContext.Type context;
762769
output FlowControl ctrl = FlowControl.NEXT;
763770
algorithm
764771
assignVariable(lhsExp,
765-
Ceval.evalExp(rhsExp, EvalTarget.new(ElementSource.getInfo(source), STATEMENT_CONTEXT)));
772+
Ceval.evalExp(rhsExp, evalTargetFromSource(source, STATEMENT_CONTEXT, context)));
766773
end evaluateAssignment;
767774

768775
public
@@ -967,6 +974,7 @@ function evaluateFor
967974
input Option<Expression> range;
968975
input list<Statement> forBody;
969976
input DAE.ElementSource source;
977+
input InstContext.Type context;
970978
output FlowControl ctrl = FlowControl.NEXT;
971979
protected
972980
RangeIterator range_iter;
@@ -976,7 +984,7 @@ protected
976984
Integer i = 0, limit = Flags.getConfigInt(Flags.EVAL_LOOP_LIMIT);
977985
algorithm
978986
range_exp := Ceval.evalExp(Util.getOption(range),
979-
EvalTarget.new(ElementSource.getInfo(source), STATEMENT_CONTEXT));
987+
evalTargetFromSource(source, STATEMENT_CONTEXT, context));
980988
range_iter := RangeIterator.fromExp(range_exp);
981989

982990
if RangeIterator.hasNext(range_iter) then
@@ -987,7 +995,7 @@ algorithm
987995
(range_iter, value) := RangeIterator.next(range_iter);
988996
// Update the mutable expression with the iteration value and evaluate the statement.
989997
Mutable.update(iter_exp, value);
990-
ctrl := evaluateStatements(body);
998+
ctrl := evaluateStatements(body, context);
991999

9921000
if ctrl <> FlowControl.NEXT then
9931001
if ctrl == FlowControl.BREAK then
@@ -1010,6 +1018,7 @@ end evaluateFor;
10101018
function evaluateIf
10111019
input list<tuple<Expression, list<Statement>>> branches;
10121020
input DAE.ElementSource source;
1021+
input InstContext.Type context;
10131022
output FlowControl ctrl;
10141023
protected
10151024
Expression cond;
@@ -1018,8 +1027,8 @@ algorithm
10181027
for branch in branches loop
10191028
(cond, body) := branch;
10201029

1021-
if Expression.isTrue(Ceval.evalExp(cond, EvalTarget.new(ElementSource.getInfo(source), IF_COND_CONTEXT))) then
1022-
ctrl := evaluateStatements(body);
1030+
if Expression.isTrue(Ceval.evalExp(cond, evalTargetFromSource(source, IF_COND_CONTEXT, context))) then
1031+
ctrl := evaluateStatements(body, context);
10231032
return;
10241033
end if;
10251034
end for;
@@ -1030,11 +1039,12 @@ end evaluateIf;
10301039
function evaluateAssert
10311040
input Expression condition;
10321041
input Statement assertStmt;
1042+
input DAE.ElementSource source;
1043+
input InstContext.Type context;
10331044
output FlowControl ctrl = FlowControl.NEXT;
10341045
protected
10351046
Expression cond, msg, lvl;
1036-
SourceInfo info = ElementSource.getInfo(Statement.source(assertStmt));
1037-
EvalTarget target = EvalTarget.new(info, STATEMENT_CONTEXT);
1047+
EvalTarget target = evalTargetFromSource(source, STATEMENT_CONTEXT, context);
10381048
algorithm
10391049
if Expression.isFalse(Ceval.evalExp(condition, target)) then
10401050
Statement.ASSERT(message = msg, level = lvl) := assertStmt;
@@ -1044,13 +1054,13 @@ algorithm
10441054
() := match (msg, lvl)
10451055
case (Expression.STRING(), Expression.ENUM_LITERAL(name = "warning"))
10461056
algorithm
1047-
Error.addSourceMessage(Error.ASSERT_TRIGGERED_WARNING, {msg.value}, info);
1057+
Error.addSourceMessage(Error.ASSERT_TRIGGERED_WARNING, {msg.value}, EvalTarget.getInfo(target));
10481058
then
10491059
();
10501060

10511061
case (Expression.STRING(), Expression.ENUM_LITERAL(name = "error"))
10521062
algorithm
1053-
Error.addSourceMessage(Error.ASSERT_TRIGGERED_ERROR, {msg.value}, info);
1063+
Error.addSourceMessage(Error.ASSERT_TRIGGERED_ERROR, {msg.value}, EvalTarget.getInfo(target));
10541064
ctrl := FlowControl.ASSERTION;
10551065
then
10561066
();
@@ -1068,22 +1078,24 @@ end evaluateAssert;
10681078
function evaluateNoRetCall
10691079
input Expression callExp;
10701080
input DAE.ElementSource source;
1081+
input InstContext.Type context;
10711082
output FlowControl ctrl = FlowControl.NEXT;
10721083
algorithm
1073-
Ceval.evalExp(callExp, EvalTarget.new(ElementSource.getInfo(source), STATEMENT_CONTEXT));
1084+
Ceval.evalExp(callExp, evalTargetFromSource(source, STATEMENT_CONTEXT, context));
10741085
end evaluateNoRetCall;
10751086

10761087
function evaluateWhile
10771088
input Expression condition;
10781089
input list<Statement> body;
10791090
input DAE.ElementSource source;
1091+
input InstContext.Type context;
10801092
output FlowControl ctrl = FlowControl.NEXT;
10811093
protected
10821094
Integer i = 0, limit = Flags.getConfigInt(Flags.EVAL_LOOP_LIMIT);
1083-
EvalTarget target = EvalTarget.new(ElementSource.getInfo(source), STATEMENT_CONTEXT);
1095+
EvalTarget target = evalTargetFromSource(source, STATEMENT_CONTEXT, context);
10841096
algorithm
10851097
while Expression.isTrue(Ceval.evalExp(condition, target)) loop
1086-
ctrl := evaluateStatements(body);
1098+
ctrl := evaluateStatements(body, context);
10871099

10881100
if ctrl <> FlowControl.NEXT then
10891101
if ctrl == FlowControl.BREAK then
@@ -1102,6 +1114,13 @@ algorithm
11021114
end while;
11031115
end evaluateWhile;
11041116

1117+
function evalTargetFromSource
1118+
input DAE.ElementSource source;
1119+
input InstContext.Type context;
1120+
input InstContext.Type currentContext;
1121+
output EvalTarget target = EvalTarget.new(ElementSource.getInfo(source), InstContext.set(context, currentContext));
1122+
end evalTargetFromSource;
1123+
11051124
function evaluateExternal2
11061125
input String name;
11071126
input Function fn;

0 commit comments

Comments
 (0)