Skip to content

Commit a233eb9

Browse files
[NB] check alias replacements for validity (#13662)
1 parent e7d6d52 commit a233eb9

File tree

7 files changed

+244
-9
lines changed

7 files changed

+244
-9
lines changed

OMCompiler/Compiler/NBackEnd/Classes/NBEquation.mo

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ public
6565
import OldBackendDAE = BackendDAE;
6666

6767
// New Backend imports
68+
import DetectStates = NBDetectStates;
6869
import Evaluation = NBEvaluation;
6970
import Inline = NBInline;
7071
import Replacements = NBReplacements;
@@ -1441,6 +1442,8 @@ public
14411442
input output Equation eq;
14421443
input String name = "";
14431444
input String indent = "";
1445+
input Pointer<list<Pointer<Variable>>> acc_discrete_states = Pointer.create({});
1446+
input Pointer<list<Pointer<Variable>>> acc_previous = Pointer.create({});
14441447
input SimplifyFunc simplifyExp = function SimplifyExp.simplifyDump(includeScope = true, name = name, indent = indent);
14451448

14461449
partial function SimplifyFunc
@@ -1499,7 +1502,9 @@ public
14991502
case SOME(body) algorithm
15001503
eq.body := body;
15011504
then eq;
1502-
else Equation.DUMMY_EQUATION();
1505+
else algorithm
1506+
DetectStates.findDiscreteStatesFromWhenBody(eq.body, acc_discrete_states, acc_previous);
1507+
then Equation.DUMMY_EQUATION();
15031508
end match;
15041509
then new_eq;
15051510

OMCompiler/Compiler/NBackEnd/Classes/NBVariable.mo

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,17 @@ public
225225
Pointer.update(par_ptr, par);
226226
end connectPartners;
227227

228+
function removePartner
229+
"removes the partner for the variable"
230+
input Pointer<Variable> var_ptr;
231+
input BackendInfo.setPartner func;
232+
protected
233+
Variable var = Pointer.access(var_ptr);
234+
algorithm
235+
var.backendinfo := func(var.backendinfo, NONE());
236+
Pointer.update(var_ptr, var);
237+
end removePartner;
238+
228239
function getVar
229240
input ComponentRef cref;
230241
output Variable var;

OMCompiler/Compiler/NBackEnd/Classes/NBackendDAE.mo

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -388,9 +388,14 @@ public
388388
input output BackendDAE bdae;
389389
input Boolean init;
390390
protected
391+
Pointer<list<Pointer<Variable>>> acc_discrete_states = Pointer.create({});
392+
Pointer<list<Pointer<Variable>>> acc_previous = Pointer.create({});
393+
391394
BEquation.MapFuncEqn func = function Equation.simplify(
392395
name = getInstanceName(),
393396
indent = "",
397+
acc_discrete_states = acc_discrete_states,
398+
acc_previous = acc_previous,
394399
simplifyExp = function SimplifyExp.simplifyDump(
395400
includeScope = true,
396401
name = getInstanceName(),
@@ -399,13 +404,40 @@ public
399404
bdae := match bdae
400405
local
401406
EqData eqData;
407+
VarData varData;
408+
list<Pointer<Variable>> acc_discrete_states_accessed;
409+
402410
case MAIN(eqData = eqData as BEquation.EQ_DATA_SIM()) algorithm
403411
if init then
404412
eqData.initials := EquationPointers.map(eqData.initials, func);
405413
else
406414
eqData.equations := EquationPointers.map(eqData.equations, func);
407415
end if;
408416
bdae.eqData := EqData.compress(eqData);
417+
418+
// update varData with accs obtained from mapping
419+
bdae.varData := match bdae.varData
420+
case varData as VarData.VAR_DATA_SIM() algorithm
421+
acc_discrete_states_accessed := Pointer.access(acc_discrete_states);
422+
423+
VariablePointers.removeList(acc_discrete_states_accessed, varData.unknowns);
424+
VariablePointers.removeList(acc_discrete_states_accessed, varData.discretes);
425+
VariablePointers.removeList(acc_discrete_states_accessed, varData.discrete_states);
426+
// TODO: CLOCKED?
427+
428+
VariablePointers.removeList(Pointer.access(acc_previous), varData.previous);
429+
VariablePointers.removeList(Pointer.access(acc_previous), varData.variables);
430+
431+
VariablePointers.addList(acc_discrete_states_accessed, varData.parameters);
432+
VariablePointers.addList(acc_discrete_states_accessed, varData.knowns);
433+
434+
for v in acc_discrete_states_accessed loop
435+
BVariable.setVarKind(v, VariableKind.PARAMETER(NONE()));
436+
BVariable.removePartner(v, BackendInfo.setVarPre);
437+
end for;
438+
then varData;
439+
else bdae.varData;
440+
end match;
409441
then bdae;
410442
else bdae;
411443
end match;
@@ -420,12 +452,17 @@ public
420452
bdae := match bdae
421453
local
422454
EqData eqData;
455+
Pointer<list<Pointer<Variable>>> acc_discrete_states = Pointer.create({});
456+
Pointer<list<Pointer<Variable>>> acc_previous = Pointer.create({});
457+
423458
case MAIN(eqData = eqData as BEquation.EQ_DATA_SIM()) algorithm
424459
eqData.equations := EquationPointers.map(
425460
eqData.equations,
426461
function Equation.simplify(
427462
name = getInstanceName(),
428463
indent = "",
464+
acc_discrete_states = acc_discrete_states,
465+
acc_previous = acc_previous,
429466
simplifyExp = SimplifyExp.removeStream));
430467
bdae.eqData := EqData.compress(eqData);
431468
then bdae;
@@ -1245,7 +1282,7 @@ protected
12451282
eq := Pointer.create(Equation.ALGORITHM(size, alg, alg.source, DAE.EXPAND(), attr));
12461283
end lowerAlgorithm;
12471284

1248-
protected function lowerEquationAttributes
1285+
function lowerEquationAttributes
12491286
input Type ty;
12501287
input Boolean init;
12511288
output EquationAttributes attr;
@@ -1259,7 +1296,7 @@ protected
12591296
end if;
12601297
end lowerEquationAttributes;
12611298

1262-
function lowerComponentReferences
1299+
protected function lowerComponentReferences
12631300
input output EquationPointers equations;
12641301
input VariablePointers variables;
12651302
algorithm

OMCompiler/Compiler/NBackEnd/Modules/2_Pre/NBAlias.mo

Lines changed: 85 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ protected
7575
import ComponentRef = NFComponentRef;
7676
import Expression = NFExpression;
7777
import ExpressionIterator = NFExpressionIterator;
78+
import NFFunction.Function;
7879
import Type = NFType;
7980
import Operator = NFOperator;
8081
import Variable = NFVariable;
@@ -220,14 +221,15 @@ protected
220221
UnorderedMap<ComponentRef, Expression> replacements;
221222
EquationPointers newEquations;
222223
list<Pointer<Variable>> alias_vars, const_vars, non_trivial_alias;
223-
list<Pointer<Equation>> non_trivial_eqs;
224+
list<Pointer<Equation>> non_trivial_eqs, auxEquations;
224225

225226
case (BVariable.VAR_DATA_SIM(), BEquation.EQ_DATA_SIM())
226227
algorithm
227228
// -----------------------------------
228229
// 1. 2. 3.
229230
// -----------------------------------
230231
(replacements, newEquations) := aliasCausalize(varData.unknowns, eqData.simulation, "Simulation");
232+
(replacements, auxEquations) := checkReplacements(replacements, eqData);
231233

232234
// -----------------------------------
233235
// 4. apply replacements
@@ -273,14 +275,92 @@ protected
273275
non_trivial_eqs := list(Equation.generateBindingEquation(var, eqData.uniqueIndex, false) for var in non_trivial_alias);
274276
eqData.removed := EquationPointers.addList(non_trivial_eqs, eqData.removed);
275277
//eqData.equations := EquationPointers.addList(non_trivial_eqs, eqData.equations);
276-
then (varData, eqData);
278+
then (varData, EqData.addUntypedList(eqData, auxEquations, false));
277279

278280
else algorithm
279281
Error.addMessage(Error.INTERNAL_ERROR, {getInstanceName() + " failed."});
280282
then fail();
281283
end match;
282284
end aliasDefault;
283285

286+
function checkReplacements
287+
"Checks validity of all replacements, returns all valid replacements and auxiliary equations"
288+
input UnorderedMap<ComponentRef, Expression> replacements;
289+
input EqData eqData;
290+
output UnorderedMap<ComponentRef, Expression> newReplacements = UnorderedMap.new<Expression>(ComponentRef.hash, ComponentRef.isEqual);
291+
output list<Pointer<Equation>> auxEquations = {};
292+
protected
293+
UnorderedSet<ComponentRef> exceptionSet = UnorderedSet.new(ComponentRef.hash, ComponentRef.isEqual);
294+
ComponentRef cref;
295+
Expression exp;
296+
Pointer<Equation> eqPtr;
297+
EquationAttributes attr;
298+
algorithm
299+
EqData.mapExp(eqData, function filterPre(acc = exceptionSet));
300+
for keyValueTpl in UnorderedMap.toList(replacements) loop
301+
(cref, exp) := keyValueTpl;
302+
if isValidReplacement(cref, exp, exceptionSet) then
303+
// replacement is valid - add to newReplacements
304+
UnorderedMap.add(cref, exp, newReplacements);
305+
else
306+
// add auxiliary equation
307+
attr := BackendDAE.lowerEquationAttributes(ComponentRef.getSubscriptedType(cref), false);
308+
eqPtr := Equation.makeAssignment(Expression.fromCref(cref), exp, EqData.getUniqueIndex(eqData), "SIM", Iterator.EMPTY(), attr);
309+
auxEquations := eqPtr :: auxEquations;
310+
end if;
311+
end for;
312+
313+
if Flags.isSet(Flags.DUMP_REPL) then
314+
dumpReplacements(newReplacements, auxEquations);
315+
end if;
316+
end checkReplacements;
317+
318+
function isValidReplacement
319+
"Checks if a replacement (cref, exp) is valid"
320+
input ComponentRef cref;
321+
input Expression exp;
322+
input UnorderedSet<ComponentRef> exceptionSet;
323+
output Boolean b = true;
324+
algorithm
325+
// TODO: possibly match cref, exp here: add if needed
326+
if UnorderedSet.contains(cref, exceptionSet) then
327+
b := false;
328+
end if;
329+
end isValidReplacement;
330+
331+
function filterPre
332+
"Filter expression for pre call"
333+
input output Expression exp;
334+
input UnorderedSet<ComponentRef> acc;
335+
algorithm
336+
() := match exp
337+
local
338+
Call call;
339+
ComponentRef cref;
340+
341+
case Expression.CALL(call = call as Call.TYPED_CALL(arguments = {Expression.CREF(cref = cref)}))
342+
guard(AbsynUtil.pathString(Function.nameConsiderBuiltin(call.fn)) == "pre") algorithm
343+
UnorderedSet.add(cref, acc);
344+
then ();
345+
346+
else ();
347+
end match;
348+
end filterPre;
349+
350+
function dumpReplacements
351+
input UnorderedMap<ComponentRef, Expression> replacements;
352+
input list<Pointer<Equation>> auxEquations = {};
353+
algorithm
354+
print(Replacements.simpleToString(replacements) + "\n");
355+
if not listEmpty(auxEquations) then
356+
print(StringUtil.headline_4("[dumprepl] Found But Illegal Alias Replacements (added as equations):"));
357+
for eqPtr in auxEquations loop
358+
print("\t" + Equation.toString(Pointer.access(eqPtr)) + "\n");
359+
end for;
360+
print("\n");
361+
end if;
362+
end dumpReplacements;
363+
284364
function aliasClocks
285365
"STEPS:
286366
1. collect alias sets (variables, equations, optional constant binding)
@@ -296,13 +376,15 @@ protected
296376
UnorderedMap<ComponentRef, Expression> replacements;
297377
EquationPointers newEquations;
298378
list<Pointer<Variable>> alias_vars;
379+
list<Pointer<Equation>> auxEquations;
299380

300381
case (BVariable.VAR_DATA_SIM(), BEquation.EQ_DATA_SIM())
301382
algorithm
302383
// -----------------------------------
303384
// 1. 2. 3.
304385
// -----------------------------------
305386
(replacements, newEquations) := aliasCausalize(varData.clocks, eqData.clocked, "Clocked");
387+
(replacements, auxEquations) := checkReplacements(replacements, eqData);
306388

307389
// -----------------------------------
308390
// 4. apply replacements
@@ -317,7 +399,7 @@ protected
317399
// remove alias variables from clocks and add to alias
318400
varData.clocks := VariablePointers.removeList(alias_vars, varData.clocks);
319401
varData.aliasVars := VariablePointers.addList(alias_vars, varData.aliasVars);
320-
then (varData, eqData);
402+
then (varData, EqData.addUntypedList(eqData, auxEquations, false));
321403

322404
else algorithm
323405
Error.addMessage(Error.INTERNAL_ERROR, {getInstanceName() + " failed."});
@@ -371,9 +453,6 @@ protected
371453
replacements := createReplacementRules(set, replacements);
372454
end for;
373455

374-
if Flags.isSet(Flags.DUMP_REPL) then
375-
print(Replacements.simpleToString(replacements) + "\n");
376-
end if;
377456
end aliasCausalize;
378457

379458
function findSimpleEquation

OMCompiler/Compiler/NBackEnd/Modules/2_Pre/NBDetectStates.mo

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -584,6 +584,33 @@ protected
584584
end if;
585585
end getPreVar;
586586

587+
public function findDiscreteStatesFromWhenBody
588+
"All variables on the LHS in a when equation are considered discrete, add these to acc lists"
589+
input WhenEquationBody body;
590+
input Pointer<list<Pointer<Variable>>> acc_discrete_states;
591+
input Pointer<list<Pointer<Variable>>> acc_previous;
592+
algorithm
593+
for body_stmt in body.when_stmts loop
594+
() := match body_stmt
595+
local
596+
ComponentRef state_cref, pre_cref;
597+
Pointer<Variable> state_var, pre_var;
598+
599+
case WhenStatement.ASSIGN(lhs = Expression.CREF(cref = state_cref)) algorithm
600+
state_var := BVariable.getVarPointer(state_cref);
601+
_ := match BVariable.getVarPre(state_var)
602+
case SOME(pre_var) algorithm
603+
Pointer.update(acc_previous, pre_var :: Pointer.access(acc_previous));
604+
then ();
605+
else ();
606+
end match;
607+
Pointer.update(acc_discrete_states, state_var :: Pointer.access(acc_discrete_states));
608+
then ();
609+
else ();
610+
end match;
611+
end for;
612+
end findDiscreteStatesFromWhenBody;
613+
587614
annotation(__OpenModelica_Interface="backend");
588615
end NBDetectStates;
589616

0 commit comments

Comments
 (0)