Skip to content

Commit

Permalink
Nb record residuals (#11827)
Browse files Browse the repository at this point in the history
* [NB] update alias module to not use record elements

[NSimCode] draft: inline record residuals

* [NB] add variable record children when adding records

* [NB] inline record residual equations

 - fixes #11556

* [testsuite] udapte tests

 - alias module does not apply to record elements anymore
  • Loading branch information
kabdelhak committed Jan 22, 2024
1 parent f0cfe93 commit 87beb39
Show file tree
Hide file tree
Showing 9 changed files with 115 additions and 57 deletions.
56 changes: 29 additions & 27 deletions OMCompiler/Compiler/NBackEnd/Classes/NBVariable.mo
Expand Up @@ -1964,64 +1964,66 @@ public
type VarType = enumeration(STATE, STATE_DER, ALGEBRAIC, DISCRETE, DISC_STATE, PREVIOUS, START, PARAMETER, ITERATOR, RECORD);

function addTypedList
"can also be used to add single variables"
input output VarData varData;
input list<Pointer<Variable>> var_lst;
input VarType varType;
algorithm
varData := match (varData, varType)

case (VAR_DATA_SIM(), VarType.STATE) algorithm
varData.variables := VariablePointers.addList(var_lst, varData.variables);
varData.knowns := VariablePointers.addList(var_lst, varData.knowns);
varData.states := VariablePointers.addList(var_lst, varData.states);
varData.initials := VariablePointers.addList(var_lst, varData.initials);
varData.variables := VariablePointers.addList(var_lst, varData.variables);
varData.knowns := VariablePointers.addList(var_lst, varData.knowns);
varData.states := VariablePointers.addList(var_lst, varData.states);
varData.initials := VariablePointers.addList(var_lst, varData.initials);
// also remove from algebraics in the case it was moved
varData.unknowns := VariablePointers.removeList(var_lst, varData.unknowns);
varData.algebraics := VariablePointers.removeList(var_lst, varData.algebraics);
varData.unknowns := VariablePointers.removeList(var_lst, varData.unknowns);
varData.algebraics := VariablePointers.removeList(var_lst, varData.algebraics);
then varData;

case (VAR_DATA_SIM(), VarType.STATE_DER) algorithm
varData.variables := VariablePointers.addList(var_lst, varData.variables);
varData.unknowns := VariablePointers.addList(var_lst, varData.unknowns);
varData.variables := VariablePointers.addList(var_lst, varData.variables);
varData.unknowns := VariablePointers.addList(var_lst, varData.unknowns);
varData.derivatives := VariablePointers.addList(var_lst, varData.derivatives);
varData.initials := VariablePointers.addList(var_lst, varData.initials);
varData.initials := VariablePointers.addList(var_lst, varData.initials);
then varData;

// algebraic variables, dummy states and dummy derivatives are mathematically equal
case (VAR_DATA_SIM(), VarType.ALGEBRAIC) algorithm
varData.variables := VariablePointers.addList(var_lst, varData.variables);
varData.unknowns := VariablePointers.addList(var_lst, varData.unknowns);
varData.algebraics := VariablePointers.addList(var_lst, varData.algebraics);
varData.initials := VariablePointers.addList(var_lst, varData.initials);
varData.variables := VariablePointers.addList(var_lst, varData.variables);
varData.unknowns := VariablePointers.addList(var_lst, varData.unknowns);
varData.algebraics := VariablePointers.addList(var_lst, varData.algebraics);
varData.initials := VariablePointers.addList(var_lst, varData.initials);
then varData;

case (VAR_DATA_SIM(), VarType.DISCRETE) algorithm
varData.variables := VariablePointers.addList(var_lst, varData.variables);
varData.unknowns := VariablePointers.addList(var_lst, varData.unknowns);
varData.discretes := VariablePointers.addList(var_lst, varData.discretes);
varData.initials := VariablePointers.addList(var_lst, varData.initials);
varData.variables := VariablePointers.addList(var_lst, varData.variables);
varData.unknowns := VariablePointers.addList(var_lst, varData.unknowns);
varData.discretes := VariablePointers.addList(var_lst, varData.discretes);
varData.initials := VariablePointers.addList(var_lst, varData.initials);
then varData;

case (VAR_DATA_SIM(), VarType.START) algorithm
varData.variables := VariablePointers.addList(var_lst, varData.variables);
varData.initials := VariablePointers.addList(var_lst, varData.initials);
varData.variables := VariablePointers.addList(var_lst, varData.variables);
varData.initials := VariablePointers.addList(var_lst, varData.initials);
then varData;

case (VAR_DATA_SIM(), VarType.PARAMETER) algorithm
varData.parameters := VariablePointers.addList(var_lst, varData.parameters);
varData.knowns := VariablePointers.addList(var_lst, varData.knowns);
varData.parameters := VariablePointers.addList(var_lst, varData.parameters);
varData.knowns := VariablePointers.addList(var_lst, varData.knowns);
then varData;

case (VAR_DATA_SIM(), VarType.ITERATOR) algorithm
varData.variables := VariablePointers.addList(var_lst, varData.variables);
varData.knowns := VariablePointers.addList(var_lst, varData.knowns);
varData.variables := VariablePointers.addList(var_lst, varData.variables);
varData.knowns := VariablePointers.addList(var_lst, varData.knowns);
then varData;

// IMPORTANT: does not add the record elements!
// IMPORTANT: requires the record elements to be added as children beforehand!
case (VAR_DATA_SIM(), VarType.RECORD) algorithm
varData.variables := VariablePointers.addList(var_lst, varData.variables);
varData.records := VariablePointers.addList(var_lst, varData.records);
varData.knowns := VariablePointers.addList(var_lst, varData.knowns);
varData.variables := VariablePointers.addList(var_lst, varData.variables);
varData.records := VariablePointers.addList(var_lst, varData.records);
varData.knowns := VariablePointers.addList(var_lst, varData.knowns);
varData.records := VariablePointers.mapPtr(varData.records, function BackendDAE.lowerRecordChildren(variables = varData.variables));
then varData;

// ToDo: other cases
Expand Down
4 changes: 2 additions & 2 deletions OMCompiler/Compiler/NBackEnd/Classes/NBackendDAE.mo
Expand Up @@ -600,7 +600,7 @@ protected
end if;
end collectVariableBindingIterators;

function lowerRecordChildren
public function lowerRecordChildren
input Pointer<Variable> var_ptr;
input VariablePointers variables;
protected
Expand All @@ -623,7 +623,7 @@ protected
Pointer.update(var_ptr, var);
end lowerRecordChildren;

function lowerEquationData
protected function lowerEquationData
"Lowers all equations to backend structure.
kabdelhak: Splitting up the creation of the equation array and the equation
pointer arrays in two steps is slightly less effective, but way more readable
Expand Down
2 changes: 1 addition & 1 deletion OMCompiler/Compiler/NBackEnd/Modules/1_Main/NBCausalize.mo
Expand Up @@ -208,7 +208,7 @@ public
Adjacency.Matrix adj;
Matching matching;
algorithm
// create scalar adjacency matrix for now
// create scalar adjacency matrix for now
adj := Adjacency.Matrix.create(vars, eqs, matrixType);
matching := Matching.regular(NBMatching.EMPTY_MATCHING, adj);
comps := Sorting.tarjan(adj, matching, vars, eqs);
Expand Down
4 changes: 2 additions & 2 deletions OMCompiler/Compiler/NBackEnd/Modules/2_Pre/NBAlias.mo
Expand Up @@ -460,9 +460,9 @@ protected
guard(BVariable.isParamOrConst(BVariable.getVarPointer(exp.cref)) or ComponentRef.isTime(exp.cref))
then tpl;

// fail for multidimensional crefs for now
// fail for multidimensional crefs and record elements for now
case Expression.CREF()
guard(BVariable.size(BVariable.getVarPointer(exp.cref)) > 1)
guard(BVariable.size(BVariable.getVarPointer(exp.cref)) > 1 or Util.isSome(BVariable.getParent(BVariable.getVarPointer(exp.cref))))
then FAILED_CREF_TPL;

// variable found
Expand Down
45 changes: 38 additions & 7 deletions OMCompiler/Compiler/NBackEnd/Modules/2_Pre/NBInline.mo
Expand Up @@ -62,6 +62,7 @@ protected
import NBVariable.{VariablePointers, VarData};

// Util
import Slice = NBSlice;
import StringUtil;

// =========================================================================
Expand Down Expand Up @@ -146,6 +147,33 @@ public
end if;
end functionInlineable;

function inlineRecords
"also inlines simple record equalities"
input output EqData eqData;
input VariablePointers variables;
protected
Pointer<Integer> index = EqData.getUniqueIndex(eqData);
Pointer<list<Pointer<Equation>>> new_eqns = Pointer.create({});
algorithm
eqData := EqData.map(eqData, function inlineRecordEquation(variables = variables, record_eqns = new_eqns, index = index, inlineSimple = true));
eqData := EqData.addUntypedList(eqData, Pointer.access(new_eqns), false);
eqData := EqData.compress(eqData);
end inlineRecords;

function inlineRecordSliceEquation
input Slice<Pointer<Equation>> slice;
input VariablePointers variables;
input Pointer<Integer> index;
input Boolean inlineSimple;
output list<Slice<Pointer<Equation>>> slices;
protected
Pointer<list<Pointer<Equation>>> record_eqns = Pointer.create({});
algorithm
inlineRecordEquation(Pointer.access(Slice.getT(slice)), variables, record_eqns, index, inlineSimple);
// somehow split slice.indices
slices := list(Slice.SLICE(eqn, {}) for eqn in Pointer.access(record_eqns));
end inlineRecordSliceEquation;

protected
function inline extends Module.inlineInterface;
protected
Expand Down Expand Up @@ -177,13 +205,14 @@ protected
end collectInlineFunctions;

function inlineRecordsTuples
"does not inline simple record equalities"
input output EqData eqData;
input VariablePointers variables;
protected
Pointer<Integer> index = EqData.getUniqueIndex(eqData);
Pointer<list<Pointer<Equation>>> new_eqns = Pointer.create({});
algorithm
eqData := EqData.map(eqData, function inlineRecordEquation(variables = variables, record_eqns = new_eqns, index = index));
eqData := EqData.map(eqData, function inlineRecordEquation(variables = variables, record_eqns = new_eqns, index = index, inlineSimple = false));
eqData := EqData.map(eqData, function inlineTupleEquation(tuple_eqns = new_eqns, index = index));
eqData := EqData.addUntypedList(eqData, Pointer.access(new_eqns), false);
eqData := EqData.compress(eqData);
Expand All @@ -197,21 +226,22 @@ protected
input VariablePointers variables;
input Pointer<list<Pointer<Equation>>> record_eqns;
input Pointer<Integer> index;
input Boolean inlineSimple;
algorithm
eqn := match eqn
local
Equation new_eqn;
Integer size;

// don't inline simple cref equalities
case Equation.RECORD_EQUATION(lhs = Expression.CREF(), rhs = Expression.CREF()) then eqn;
case Equation.ARRAY_EQUATION(lhs = Expression.CREF(), rhs = Expression.CREF()) then eqn;
case Equation.RECORD_EQUATION(lhs = Expression.CREF(), rhs = Expression.CREF()) guard(not inlineSimple) then eqn;
case Equation.ARRAY_EQUATION(lhs = Expression.CREF(), rhs = Expression.CREF()) guard(not inlineSimple) then eqn;

// try to inline other record equations. try catch to be sure to not discard
case Equation.RECORD_EQUATION(ty = Type.COMPLEX()) algorithm
try
if Flags.isSet(Flags.DUMPBACKENDINLINE) then print("[" + getInstanceName() + "] Inlining: " + Equation.toString(eqn) + "\n"); end if;
new_eqn := inlineRecordEquationWork(eqn.lhs, eqn.rhs, eqn.attr, eqn.source, eqn.recordSize, variables, record_eqns, index);
new_eqn := inlineRecordEquationWork(eqn.lhs, eqn.rhs, eqn.attr, eqn.source, eqn.recordSize, variables, record_eqns, index, inlineSimple);
if Flags.isSet(Flags.DUMPBACKENDINLINE) then print("\n"); end if;
else
// inlining failed, keep old equation
Expand All @@ -223,7 +253,7 @@ protected
case Equation.ARRAY_EQUATION(recordSize = SOME(size)) algorithm
try
if Flags.isSet(Flags.DUMPBACKENDINLINE) then print("[" + getInstanceName() + "] Inlining: " + Equation.toString(eqn) + "\n"); end if;
new_eqn := inlineRecordEquationWork(eqn.lhs, eqn.rhs, eqn.attr, eqn.source, size, variables, record_eqns, index);
new_eqn := inlineRecordEquationWork(eqn.lhs, eqn.rhs, eqn.attr, eqn.source, size, variables, record_eqns, index, inlineSimple);
else
// inlining failed, keep old equation
new_eqn := eqn;
Expand All @@ -232,7 +262,7 @@ protected

// iterate over body equations of for-loop
case Equation.FOR_EQUATION() algorithm
eqn.body := list(inlineRecordEquation(body_eqn, variables, record_eqns, index) for body_eqn in eqn.body);
eqn.body := list(inlineRecordEquation(body_eqn, variables, record_eqns, index, inlineSimple) for body_eqn in eqn.body);
then eqn;

else eqn;
Expand All @@ -248,6 +278,7 @@ protected
input VariablePointers variables;
input Pointer<list<Pointer<Equation>>> record_eqns;
input Pointer<Integer> index;
input Boolean inlineSimple;
output Equation new_eqn;
protected
list<Pointer<Equation>> tmp_eqns;
Expand All @@ -274,7 +305,7 @@ protected
// if the equation still has a record type, inline it further
if Equation.isRecordEquation(tmp_eqn) then
tmp_eqns_ptr := Pointer.create(tmp_eqns);
_ := inlineRecordEquation(Pointer.access(tmp_eqn), variables, tmp_eqns_ptr, index);
_ := inlineRecordEquation(Pointer.access(tmp_eqn), variables, tmp_eqns_ptr, index, inlineSimple);
tmp_eqns := Pointer.access(tmp_eqns_ptr);
else
tmp_eqns := tmp_eqn :: tmp_eqns;
Expand Down
46 changes: 34 additions & 12 deletions OMCompiler/Compiler/NBackEnd/Modules/3_Post/NBTearing.mo
Expand Up @@ -57,6 +57,7 @@ protected
import BJacobian = NBJacobian;
import BVariable = NBVariable;
import Differentiate = NBDifferentiate;
import Inline = NBInline;
import Jacobian = NBackendDAE.BackendDAE;
import Matching = NBMatching;
import Sorting = NBSorting;
Expand Down Expand Up @@ -119,28 +120,30 @@ public
bdae := match (systemType, bdae)
local
list<System.System> systems;
VariablePointers variables;
Pointer<Integer> eq_index;

case (NBSystem.SystemType.ODE, BackendDAE.MAIN(ode = systems, funcTree = funcTree))
case (NBSystem.SystemType.ODE, BackendDAE.MAIN(ode = systems, funcTree = funcTree, varData = BVariable.VAR_DATA_SIM(variables = variables), eqData = BEquation.EQ_DATA_SIM(uniqueIndex = eq_index)))
algorithm
(systems, funcTree) := tearingTraverser(systems, func, funcTree, systemType);
(systems, funcTree) := tearingTraverser(systems, func, funcTree, variables, eq_index, systemType);
bdae.ode := systems;
bdae.funcTree := funcTree;
then bdae;

case (NBSystem.SystemType.INI, BackendDAE.MAIN(init = systems, funcTree = funcTree))
case (NBSystem.SystemType.INI, BackendDAE.MAIN(init = systems, funcTree = funcTree, varData = BVariable.VAR_DATA_SIM(variables = variables), eqData = BEquation.EQ_DATA_SIM(uniqueIndex = eq_index)))
algorithm
(systems, funcTree) := tearingTraverser(systems, func, funcTree, systemType);
(systems, funcTree) := tearingTraverser(systems, func, funcTree, variables, eq_index, systemType);
bdae.init := systems;
if Util.isSome(bdae.init_0) then
(systems, funcTree) := tearingTraverser(Util.getOption(bdae.init_0), func, funcTree, systemType);
(systems, funcTree) := tearingTraverser(Util.getOption(bdae.init_0), func, funcTree, variables, eq_index, systemType);
bdae.init_0 := SOME(systems);
end if;
bdae.funcTree := funcTree;
then bdae;

case (NBSystem.SystemType.DAE, BackendDAE.MAIN(dae = SOME(systems), funcTree = funcTree))
case (NBSystem.SystemType.DAE, BackendDAE.MAIN(dae = SOME(systems), funcTree = funcTree, varData = BVariable.VAR_DATA_SIM(variables = variables), eqData = BEquation.EQ_DATA_SIM(uniqueIndex = eq_index)))
algorithm
(systems, funcTree) := tearingTraverser(systems, func, funcTree, systemType);
(systems, funcTree) := tearingTraverser(systems, func, funcTree, variables, eq_index, systemType);
bdae.dae := SOME(systems);
bdae.funcTree := funcTree;
then bdae;
Expand All @@ -149,7 +152,11 @@ public
end match;
end main;

function implicit extends Module.tearingInterface;
function implicit
input output StrongComponent comp "the suspected algebraic loop.";
input output FunctionTree funcTree "Function call bodies";
input output Integer index "current unique loop index";
input System.SystemType systemType = NBSystem.SystemType.ODE "system type";
algorithm
(comp, funcTree, index) := match comp
// create implicit equations
Expand All @@ -160,7 +167,7 @@ public
casual = NONE(),
linear = false,
mixed = false,
status = NBSolve.Status.IMPLICIT), funcTree, index, systemType);
status = NBSolve.Status.IMPLICIT), funcTree, index, VariablePointers.empty(), Pointer.create(0), systemType);

case StrongComponent.MULTI_COMPONENT()
then tearingNone(StrongComponent.ALGEBRAIC_LOOP(
Expand All @@ -169,7 +176,7 @@ public
casual = NONE(),
linear = false,
mixed = false,
status = NBSolve.Status.IMPLICIT), funcTree, index, systemType);
status = NBSolve.Status.IMPLICIT), funcTree, index, VariablePointers.empty(), Pointer.create(0), systemType);

// do nothing otherwise
else (comp, funcTree, index);
Expand Down Expand Up @@ -222,6 +229,8 @@ protected
input Module.tearingInterface func;
output list<System.System> new_systems = {};
input output FunctionTree funcTree;
input VariablePointers variables;
input Pointer<Integer> eq_index;
input System.SystemType systemType;
protected
array<StrongComponent> strongComponents;
Expand All @@ -232,7 +241,7 @@ protected
if isSome(syst.strongComponents) then
SOME(strongComponents) := syst.strongComponents;
for i in 1:arrayLength(strongComponents) loop
(tmp, funcTree, idx) := func(strongComponents[i], funcTree, idx, systemType);
(tmp, funcTree, idx) := func(strongComponents[i], funcTree, idx, variables, eq_index, systemType);
// only update if it changed
if not referenceEq(tmp, strongComponents[i]) then
arrayUpdate(strongComponents, i, tmp);
Expand All @@ -251,14 +260,27 @@ protected
list<StrongComponent> residual_comps;
Option<Jacobian> jacobian;
Tearing strict;
protected
list<Slice<EquationPointer>> tmp;
list<list<Slice<EquationPointer>>> acc = {};
algorithm
(comp, index) := match comp
case StrongComponent.ALGEBRAIC_LOOP(strict = strict) algorithm
index := index + 1;
comp.idx := index;

for eqn in listReverse(strict.residual_eqns) loop
tmp := Inline.inlineRecordSliceEquation(eqn, variables, eq_index, true);
if listEmpty(tmp) then
acc := {eqn} :: acc;
else
acc := tmp :: acc;
end if;
end for;

// create residual equations
strict.residual_eqns := list(Slice.apply(eqn, function Equation.createResidual(new = true)) for eqn in strict.residual_eqns);
strict.residual_eqns := list(Slice.apply(eqn, function Equation.createResidual(new = true)) for eqn in List.flatten(acc));
comp.strict := strict;
residual_comps := list(StrongComponent.fromSolvedEquationSlice(eqn) for eqn in strict.residual_eqns);

// update jacobian to take slices (just to have correct inner variables and such)
Expand Down
2 changes: 2 additions & 0 deletions OMCompiler/Compiler/NBackEnd/Modules/NBModule.mo
Expand Up @@ -287,6 +287,8 @@ public
input output StrongComponent comp "the suspected algebraic loop.";
input output FunctionTree funcTree "Function call bodies";
input output Integer index "current unique loop index";
input VariablePointers variables "all variables";
input Pointer<Integer> eq_index "equation index";
input System.SystemType systemType = NBSystem.SystemType.ODE "system type";
end tearingInterface;

Expand Down

0 comments on commit 87beb39

Please sign in to comment.