Skip to content

Commit

Permalink
[NB] update initialization (#8672)
Browse files Browse the repository at this point in the history
- fixed discrete states create $PRE.d = $START.d
 - balance initialization creates $PRE.d = $START.d for unmatched pre vars
 - balance initialization creates d = $PRE.d for unmatched discrete states
  • Loading branch information
kabdelhak committed Mar 10, 2022
1 parent 9e4f2b5 commit de8cb1e
Show file tree
Hide file tree
Showing 12 changed files with 308 additions and 1,782 deletions.
28 changes: 11 additions & 17 deletions OMCompiler/Compiler/NBackEnd/Classes/NBEquation.mo
Expand Up @@ -73,6 +73,10 @@ public
import StringUtil;
import UnorderedMap;

constant String SIMULATION_STR = "SIM";
constant String START_STR = "SRT";
constant String PRE_STR = "PRE";

type EquationPointer = Pointer<Equation> "mainly used for mapping purposes";
type Frame = tuple<ComponentRef, Expression> "iterator-like tuple for array handling";
type FrameLocation = tuple<array<Integer>, Frame> "sliced frame at specific sub locations";
Expand Down Expand Up @@ -605,11 +609,12 @@ public
end try;
end getResidualVar;

function makeStartEq
function makeEq
" x = $START.x"
input ComponentRef lhs;
input ComponentRef rhs;
input Pointer<Integer> idx;
input String str;
input list<Frame> frames = {};
output Pointer<Equation> eq;
protected
Expand Down Expand Up @@ -643,19 +648,8 @@ public
attr = EQ_ATTR_DEFAULT_INITIAL
));
end if;
Equation.createName(eq, idx, "SRT");
end makeStartEq;

function makePreEq
"$PRE.d = d"
input ComponentRef lhs;
input ComponentRef rhs;
input Pointer<Integer> idx;
output Pointer<Equation> eq;
algorithm
eq := Pointer.create(SIMPLE_EQUATION(ComponentRef.getSubscriptedType(lhs, true), lhs, rhs, DAE.emptyElementSource, EQ_ATTR_DEFAULT_INITIAL));
Equation.createName(eq, idx, "PRE");
end makePreEq;
Equation.createName(eq, idx, str);
end makeEq;

function forEquationToString
input Iterator iter "the iterator variable(s)";
Expand Down Expand Up @@ -2722,7 +2716,7 @@ public
case (EQ_DATA_SIM(), EqType.CONTINUOUS) algorithm
if newName then
for eqn_ptr in eq_lst loop
Equation.createName(eqn_ptr, eqData.uniqueIndex, "SIM");
Equation.createName(eqn_ptr, eqData.uniqueIndex, SIMULATION_STR);
end for;
end if;
eqData.equations := EquationPointers.addList(eq_lst, eqData.equations);
Expand All @@ -2733,7 +2727,7 @@ public
case (EQ_DATA_SIM(), EqType.DISCRETE) algorithm
if newName then
for eqn_ptr in eq_lst loop
Equation.createName(eqn_ptr, eqData.uniqueIndex, "SIM");
Equation.createName(eqn_ptr, eqData.uniqueIndex, SIMULATION_STR);
end for;
end if;
eqData.equations := EquationPointers.addList(eq_lst, eqData.equations);
Expand All @@ -2744,7 +2738,7 @@ public
case (EQ_DATA_SIM(), EqType.INITIAL) algorithm
if newName then
for eqn_ptr in eq_lst loop
Equation.createName(eqn_ptr, eqData.uniqueIndex, "SIM");
Equation.createName(eqn_ptr, eqData.uniqueIndex, SIMULATION_STR);
end for;
end if;
eqData.equations := EquationPointers.addList(eq_lst, eqData.equations);
Expand Down
35 changes: 30 additions & 5 deletions OMCompiler/Compiler/NBackEnd/Classes/NBVariable.mo
Expand Up @@ -458,14 +458,16 @@ public
end isFixed;

function isFixable
"states, discretes and parameters are fixable if they are not already fixed.
discrete states are always fixable. previous vars are only fixable if the discrete state for it wasn't fixed."
input Pointer<Variable> var;
output Boolean b;
algorithm
b := match Pointer.access(var)
case Variable.VARIABLE(backendinfo = BackendExtension.BACKEND_INFO(varKind = BackendExtension.STATE())) then not isFixed(var);
case Variable.VARIABLE(backendinfo = BackendExtension.BACKEND_INFO(varKind = BackendExtension.DISCRETE())) then not isFixed(var);
case Variable.VARIABLE(backendinfo = BackendExtension.BACKEND_INFO(varKind = BackendExtension.DISCRETE_STATE())) then not isFixed(var);
case Variable.VARIABLE(backendinfo = BackendExtension.BACKEND_INFO(varKind = BackendExtension.PREVIOUS())) then not isFixed(var);
case Variable.VARIABLE(backendinfo = BackendExtension.BACKEND_INFO(varKind = BackendExtension.DISCRETE_STATE())) then true;
case Variable.VARIABLE(backendinfo = BackendExtension.BACKEND_INFO(varKind = BackendExtension.PREVIOUS())) then not isFixed(getDiscreteStateVar(var));
case Variable.VARIABLE(backendinfo = BackendExtension.BACKEND_INFO(varKind = BackendExtension.PARAMETER())) then not isFixed(var);
else false;
end match;
Expand Down Expand Up @@ -537,7 +539,7 @@ public

function makeStateVar
"Updates a variable pointer to be a state, requires the pointer to its derivative."
input output Pointer<Variable> varPointer;
input Pointer<Variable> varPointer;
input Pointer<Variable> derivative;
protected
Variable var;
Expand Down Expand Up @@ -682,6 +684,29 @@ public
end match;
end getDiscreteStateVar;

function getDiscreteStateCref
"Returns the discrete state variable component reference from a previous reference.
Only works after the discrete state has been detected by the DetectStates module and fails for non-previous crefs!"
input output ComponentRef cref;
algorithm
cref := match cref
local
Pointer<Variable> previous, state;
Variable stateVar;
case ComponentRef.CREF(node = InstNode.VAR_NODE(varPointer = previous)) then match Pointer.access(previous)
case Variable.VARIABLE(backendinfo = BackendExtension.BACKEND_INFO(varKind = BackendExtension.PREVIOUS(state = state)))
algorithm
stateVar := Pointer.access(state);
then stateVar.name;
else algorithm
Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed for " + ComponentRef.toString(cref) + " because of wrong variable kind."});
then fail();
end match;
else algorithm
Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed for " + ComponentRef.toString(cref) + " because of wrong InstNode type."});
then fail();
end match;
end getDiscreteStateCref;

function makeDummyState
input Pointer<Variable> varPointer;
Expand Down Expand Up @@ -735,7 +760,7 @@ public

function makeDiscreteStateVar
"Updates a discrete variable pointer to be a discrete state, requires the pointer to its left limit (pre) variable."
input output Pointer<Variable> varPointer;
input Pointer<Variable> varPointer;
input Pointer<Variable> previous;
protected
Variable var;
Expand Down Expand Up @@ -992,7 +1017,7 @@ public
// create the new variable pointer and safe it to the component reference
(var_ptr, cref) := makeVarPtrCyclic(var, cref);
(der_cref, der_var) := BVariable.makeDerVar(cref);
var_ptr := BVariable.makeStateVar(var_ptr, der_var);
BVariable.makeStateVar(var_ptr, der_var);
end makeAuxStateVar;

function getBindingVariability
Expand Down
2 changes: 1 addition & 1 deletion OMCompiler/Compiler/NBackEnd/Classes/NBackendDAE.mo
Expand Up @@ -590,7 +590,7 @@ protected
algorithm
equation_lst := lowerEquationsAndAlgorithms(eq_lst, al_lst, init_eq_lst, init_al_lst);
for eqn_ptr in equation_lst loop
BEquation.Equation.createName(eqn_ptr, idx, "SIM");
BEquation.Equation.createName(eqn_ptr, idx, NBEquation.SIMULATION_STR);
iterators := listAppend(Equation.getForIterators(Pointer.access(eqn_ptr)), iterators);
end for;
iterators := List.uniqueOnTrue(iterators, ComponentRef.isEqual);
Expand Down
19 changes: 10 additions & 9 deletions OMCompiler/Compiler/NBackEnd/Modules/1_Main/NBCausalize.mo
Expand Up @@ -225,24 +225,25 @@ protected
(fixable, unfixable) := List.splitOnTrue(VariablePointers.toList(system.unknowns), BVariable.isFixable);
(initials, simulation) := List.splitOnTrue(EquationPointers.toList(system.equations), Equation.isInitial);

// ToDo: it should be:
// Phase I initial equations <-> unfixables (maybe simulation eqns first? to be tested)
// Phase II all equations <-> unfixables
// Phase III all equations <-> all variables

// #################################################
// Phase I: match sim equations <-> unfixable vars
// Phase I: match initial equations <-> unfixable vars
// #################################################
variables := VariablePointers.fromList(unfixable);
equations := EquationPointers.fromList(simulation);
equations := EquationPointers.fromList(initials);
adj := Adjacency.Matrix.create(variables, equations, NBAdjacency.MatrixType.PSEUDO, NBAdjacency.MatrixStrictness.SOLVABLE);
// do not resolve potential singular systems in Phase I! -> regular matching
matching := Matching.regular(Matching.EMPTY_MATCHING(), adj, true, true);

// #################################################
// Phase II: match all equations <-> all vars
// Phase II: match all equations <-> unfixables
// #################################################
(adj, variables, equations) := Adjacency.Matrix.expand(adj, variables, equations, {}, simulation);
matching := Matching.regular(matching, adj, true, true);

// #################################################
// Phase III: match all equations <-> all vars
// #################################################
(adj, variables, equations) := Adjacency.Matrix.expand(adj, variables, equations, fixable, initials);
(adj, variables, equations) := Adjacency.Matrix.expand(adj, variables, equations, fixable, {});
(matching, adj, variables, equations, funcTree, varData, eqData) := Matching.singular(matching, adj, variables, equations, funcTree, varData, eqData, system.systemType, false, true, false);

adj := Adjacency.Matrix.create(variables, equations, NBAdjacency.MatrixType.PSEUDO, NBAdjacency.MatrixStrictness.FULL);
Expand Down
142 changes: 103 additions & 39 deletions OMCompiler/Compiler/NBackEnd/Modules/1_Main/NBInitialization.mo
Expand Up @@ -83,8 +83,8 @@ public
eqData = eqData as BEquation.EQ_DATA_SIM(equations = equations, initials = initialEqs))
algorithm
// create the equations from fixed variables.
(variables, equations, initialEqs) := createStartEquations(varData.states, variables, equations, initialEqs, eqData.uniqueIndex);
(variables, equations, initialEqs) := createStartEquations(varData.discretes, variables, equations, initialEqs, eqData.uniqueIndex);
(variables, equations, initialEqs) := createStartEquations(varData.states, variables, equations, initialEqs, eqData.uniqueIndex, "State");
(variables, equations, initialEqs) := createStartEquations(varData.discretes, variables, equations, initialEqs, eqData.uniqueIndex, "Discrete State");
(equations, initialEqs, initialVars) := createParameterEquations(varData.parameters, equations, initialEqs, initialVars, eqData.uniqueIndex);

varData.variables := variables;
Expand Down Expand Up @@ -128,6 +128,7 @@ public
input output BEquation.EquationPointers equations;
input output BEquation.EquationPointers initialEqs;
input Pointer<Integer> idx;
input String str "only for debugging dump";
protected
Pointer<list<Pointer<Variable>>> ptr_start_vars = Pointer.create({});
Pointer<list<Pointer<BEquation.Equation>>> ptr_start_eqs = Pointer.create({});
Expand All @@ -143,7 +144,7 @@ public
initialEqs := BEquation.EquationPointers.addList(start_eqs, initialEqs);

if Flags.isSet(Flags.INITIALIZATION) and not listEmpty(start_eqs) then
print(List.toString(start_eqs, function Equation.pointerToString(str = ""), StringUtil.headline_4("Created Start Equations:"), "\t", "\n\t", "", false) + "\n\n");
print(List.toString(start_eqs, function Equation.pointerToString(str = ""), StringUtil.headline_4("Created " + str + " Start Equations:"), "\t", "\n\t", "", false) + "\n\n");
end if;
end createStartEquations;

Expand Down Expand Up @@ -172,52 +173,20 @@ public
algorithm
name := BVariable.getVarName(state);
(var_ptr, name, start_var, start_name) := createStartVar(state, name, {});
start_eq := BEquation.Equation.makeStartEq(name, start_name, idx);
start_eq := BEquation.Equation.makeEq(name, start_name, idx, NBEquation.START_STR);
Pointer.update(ptr_start_vars, start_var :: Pointer.access(ptr_start_vars));
Pointer.update(ptr_start_eqs, start_eq :: Pointer.access(ptr_start_eqs));
then ();
else ();
end match;
end createStartEquation;

function createStartEquationSlice
"creates a start equation for a sliced variable.
usually results in a for equation, but might be scalarized if that is not possible."
input Slice<VariablePointer> state;
input Pointer<list<Pointer<Variable>>> ptr_start_vars;
input Pointer<list<Pointer<BEquation.Equation>>> ptr_start_eqs;
input Pointer<Integer> idx;
protected
Pointer<Variable> var_ptr, start_var;
ComponentRef name, start_name;
list<Dimension> dims;
list<InstNode> iterators;
list<Expression> ranges;
list<Subscript> subscripts;
list<tuple<ComponentRef, Expression>> frames;
Pointer<Equation> start_eq;
algorithm
var_ptr := Slice.getT(state);
name := BVariable.getVarName(var_ptr);
dims := Type.arrayDims(ComponentRef.nodeType(name));
(iterators, ranges, subscripts) := Flatten.makeIterators(name, dims);
frames := List.zip(list(ComponentRef.makeIterator(iter, Type.INTEGER()) for iter in iterators), ranges);
(var_ptr, name, start_var, start_name) := createStartVar(var_ptr, name, subscripts);
start_eq := BEquation.Equation.makeStartEq(name, start_name, idx, frames);
if listEmpty(state.indices) then
// empty list indicates full array, slice otherwise
(start_eq, _, _) := Equation.slice(start_eq, state.indices, NONE(), FunctionTreeImpl.EMPTY());
end if;
Pointer.update(ptr_start_vars, start_var :: Pointer.access(ptr_start_vars));
Pointer.update(ptr_start_eqs, start_eq :: Pointer.access(ptr_start_eqs));
end createStartEquationSlice;

function createStartVar
"creates start variable and cref.
for discrete states the variable itself is changed to its
pre variable because they have to be initialized instead!.
normal: var = $START.var
disc state: $PRE.dst = $START.dst"
normal: var = $START.var
disc state and pre: $PRE.dst = $START.dst"
input output Pointer<Variable> var_ptr;
input output ComponentRef name;
input list<Subscript> subscripts;
Expand All @@ -227,11 +196,19 @@ public
Pointer<Variable> disc_state_var;
ComponentRef merged_name;
algorithm
merged_name := ComponentRef.mergeSubscripts(subscripts, name);
if BVariable.isDiscreteState(var_ptr) then
// for discrete states change the lhs cref to the $PRE cref
merged_name := ComponentRef.mergeSubscripts(subscripts, name);
name := BVariable.getPreCref(name);
name := ComponentRef.mergeSubscripts(subscripts, name);
var_ptr := BVariable.getVarPointer(name);
elseif BVariable.isPrevious(var_ptr) then
// for previous change the rhs to the start value of the discrete state
merged_name := BVariable.getDiscreteStateCref(name);
merged_name := ComponentRef.mergeSubscripts(subscripts, merged_name);
else
// just apply subscripts and make start var
merged_name := ComponentRef.mergeSubscripts(subscripts, name);
end if;
(start_name, start_var) := BVariable.makeStartVar(merged_name);
end createStartVar;
Expand Down Expand Up @@ -266,5 +243,92 @@ public
end if;
end createParameterEquations;

function createStartEquationSlice
"creates a start equation for a sliced variable.
usually results in a for equation, but might be scalarized if that is not possible."
input Slice<VariablePointer> state;
input Pointer<list<Pointer<Variable>>> ptr_start_vars;
input Pointer<list<Pointer<BEquation.Equation>>> ptr_start_eqs;
input Pointer<Integer> idx;
protected
Pointer<Variable> var_ptr, start_var;
ComponentRef name, start_name;
list<Dimension> dims;
list<InstNode> iterators;
list<Expression> ranges;
list<Subscript> subscripts;
list<tuple<ComponentRef, Expression>> frames;
Pointer<Equation> start_eq;
algorithm
var_ptr := Slice.getT(state);
name := BVariable.getVarName(var_ptr);
dims := Type.arrayDims(ComponentRef.nodeType(name));
(iterators, ranges, subscripts) := Flatten.makeIterators(name, dims);
frames := List.zip(list(ComponentRef.makeIterator(iter, Type.INTEGER()) for iter in iterators), ranges);
(var_ptr, name, start_var, start_name) := createStartVar(var_ptr, name, subscripts);
start_eq := BEquation.Equation.makeEq(name, start_name, idx, NBEquation.START_STR, frames);
if not listEmpty(state.indices) then
// empty list indicates full array, slice otherwise
(start_eq, _, _) := Equation.slice(start_eq, state.indices, NONE(), FunctionTreeImpl.EMPTY());
end if;
Pointer.update(ptr_start_vars, start_var :: Pointer.access(ptr_start_vars));
Pointer.update(ptr_start_eqs, start_eq :: Pointer.access(ptr_start_eqs));
end createStartEquationSlice;


function createPreEquation
"creates d = $PRE.d equations"
input Pointer<Variable> disc_state;
input Pointer<list<Pointer<BEquation.Equation>>> ptr_pre_eqs;
input Pointer<Integer> idx;
algorithm
_ := match Pointer.access(disc_state)
local
Pointer<Variable> previous;
Pointer<BEquation.Equation> pre_eq;
case Variable.VARIABLE(backendinfo = BackendExtension.BACKEND_INFO(varKind = BackendExtension.VariableKind.DISCRETE_STATE(previous = previous)))
algorithm
pre_eq := BEquation.Equation.makeEq(BVariable.getVarName(disc_state), BVariable.getVarName(previous), idx, NBEquation.PRE_STR);
Pointer.update(ptr_pre_eqs, pre_eq :: Pointer.access(ptr_pre_eqs));
then ();
else ();
end match;
end createPreEquation;

function createPreEquationSlice
"creates a pre equation for a sliced variable.
usually results in a for equation, but might be scalarized if that is not possible."
input Slice<VariablePointer> disc_state;
input Pointer<list<Pointer<BEquation.Equation>>> ptr_pre_eqs;
input Pointer<Integer> idx;
protected
Pointer<Variable> var_ptr;
ComponentRef name, pre_name;
list<Dimension> dims;
list<InstNode> iterators;
list<Expression> ranges;
list<Subscript> subscripts;
list<tuple<ComponentRef, Expression>> frames;
Pointer<Equation> pre_eq;
algorithm
var_ptr := Slice.getT(disc_state);
name := BVariable.getVarName(var_ptr);
dims := Type.arrayDims(ComponentRef.nodeType(name));
(iterators, ranges, subscripts) := Flatten.makeIterators(name, dims);
frames := List.zip(list(ComponentRef.makeIterator(iter, Type.INTEGER()) for iter in iterators), ranges);

pre_name := BVariable.getPreCref(name);
pre_name := ComponentRef.mergeSubscripts(subscripts, pre_name);
name := ComponentRef.mergeSubscripts(subscripts, name);

pre_eq := BEquation.Equation.makeEq(name, pre_name, idx, NBEquation.PRE_STR);

if not listEmpty(disc_state.indices) then
// empty list indicates full array, slice otherwise
(pre_eq, _, _) := Equation.slice(pre_eq, disc_state.indices, NONE(), FunctionTreeImpl.EMPTY());
end if;
Pointer.update(ptr_pre_eqs, pre_eq :: Pointer.access(ptr_pre_eqs));
end createPreEquationSlice;

annotation(__OpenModelica_Interface="backend");
end NBInitialization;

0 comments on commit de8cb1e

Please sign in to comment.