Skip to content

Commit

Permalink
[NB] udapte discrete state handling (#11405)
Browse files Browse the repository at this point in the history
* [NB] udapte discrete state handling

 - only collect discrete states from LHS of when equations
 - collect pre variables for all variables and dont assume they are discrete states
 - update start/pre initial equations
 - update variable structure
 - update -d=backenddaeinfo
  • Loading branch information
kabdelhak committed Oct 24, 2023
1 parent dab3069 commit 69d9151
Show file tree
Hide file tree
Showing 17 changed files with 321 additions and 300 deletions.
9 changes: 9 additions & 0 deletions OMCompiler/Compiler/NBackEnd/Classes/NBStrongComponent.mo
Expand Up @@ -306,6 +306,15 @@ public
end match;
end isEqual;

function removeAlias
input output StrongComponent comp;
algorithm
comp := match comp
case ALIAS() then comp.original;
else comp;
end match;
end removeAlias;

function createPseudoSlice
input Integer eqn_arr_idx;
input ComponentRef cref_to_solve;
Expand Down
16 changes: 16 additions & 0 deletions OMCompiler/Compiler/NBackEnd/Classes/NBSystem.mo
Expand Up @@ -304,6 +304,22 @@ public
end if;
end clone;

function removeAlias
"removes alias strong components and replaces it with their original strong components.
used before differentiating for jacobians."
input output System sys;
protected
array<StrongComponent> comps;
algorithm
if Util.isSome(sys.strongComponents) then
// no need to override comps afterwards since arrays are mutable
comps := Util.getOption(sys.strongComponents);
for i in 1:arrayLength(comps) loop
comps[i] := StrongComponent.removeAlias(comps[i]);
end for;
end if;
end removeAlias;

protected
function partitionKindString
input PartitionKind partitionKind;
Expand Down
152 changes: 68 additions & 84 deletions OMCompiler/Compiler/NBackEnd/Classes/NBVariable.mo
Expand Up @@ -87,7 +87,7 @@ public
constant Variable TIME_VARIABLE = Variable.VARIABLE(NFBuiltin.TIME_CREF, Type.REAL(),
NFBinding.EMPTY_BINDING, NFPrefixes.Visibility.PUBLIC, NFAttributes.DEFAULT_ATTR,
{}, {}, NONE(), SCodeUtil.dummyInfo, BackendExtension.BACKEND_INFO(
VariableKind.TIME(), NFBackendExtension.EMPTY_VAR_ATTR_REAL, NFBackendExtension.EMPTY_ANNOTATIONS));
VariableKind.TIME(), NFBackendExtension.EMPTY_VAR_ATTR_REAL, NFBackendExtension.EMPTY_ANNOTATIONS, NONE()));

constant String DERIVATIVE_STR = "$DER";
constant String DUMMY_DERIVATIVE_STR = "$dDER";
Expand Down Expand Up @@ -170,6 +170,20 @@ public
Pointer.update(var_ptr, var);
end makeVarPtrCyclic;

function connectPrePostVar
"sets the pre() var for the variable and also sets the variable pointer at the pre() variable"
input Pointer<Variable> var_ptr;
input Pointer<Variable> pre_ptr;
protected
Variable var = Pointer.access(var_ptr);
Variable pre = Pointer.access(pre_ptr);
algorithm
var.backendinfo := BackendInfo.setPrePost(var.backendinfo, SOME(pre_ptr));
pre.backendinfo := BackendInfo.setPrePost(pre.backendinfo, SOME(var_ptr));
Pointer.update(var_ptr, var);
Pointer.update(pre_ptr, pre);
end connectPrePostVar;

function getVar
input ComponentRef cref;
output Variable var;
Expand Down Expand Up @@ -331,6 +345,39 @@ public
end match;
end isPrevious;

function getPrePost
"gets the pre() / previous() var if its a variable / clocked variable or the other way around"
input Pointer<Variable> var_ptr;
output Option<Pointer<Variable>> pre_post;
protected
Variable var = Pointer.access(var_ptr);
algorithm
pre_post := var.backendinfo.pre_post;
end getPrePost;

function getPrePostCref
"only use if you are sure there is a pre-post variable"
input ComponentRef cref;
output ComponentRef pre_post;
protected
Option<Pointer<Variable>> pre_post_opt;
algorithm
pre_post_opt := getPrePost(getVarPointer(cref));
if Util.isSome(pre_post_opt) then
pre_post := getVarName(Util.getOption(pre_post_opt));
else
Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed for " + ComponentRef.toString(cref) + " because it had no pre or post variable."});
fail();
end if;
end getPrePostCref;

function hasPre
"only returns true if the variable itself is not a pre() or previous() and has a pre() pointer set"
extends checkVar;
algorithm
b := not isPrevious(var_ptr) and Util.isSome(getPrePost(var_ptr));
end hasPre;

function isDummyState extends checkVar;
algorithm
b := match Pointer.access(var_ptr)
Expand Down Expand Up @@ -427,10 +474,11 @@ public
algorithm
b := match Pointer.access(var_ptr)
case Variable.VARIABLE(backendinfo = BackendExtension.BACKEND_INFO(varKind = BackendExtension.STATE())) then not isFixed(var_ptr);
case Variable.VARIABLE(backendinfo = BackendExtension.BACKEND_INFO(varKind = BackendExtension.DISCRETE())) then not isFixed(var_ptr);
case Variable.VARIABLE(backendinfo = BackendExtension.BACKEND_INFO(varKind = BackendExtension.DISCRETE_STATE())) then true;
case Variable.VARIABLE(backendinfo = BackendExtension.BACKEND_INFO(varKind = BackendExtension.PREVIOUS())) then not isFixed(getDiscreteStateVar(var_ptr));
case Variable.VARIABLE(backendinfo = BackendExtension.BACKEND_INFO(varKind = BackendExtension.ALGEBRAIC())) then not isFixed(var_ptr) or hasPre(var_ptr);
case Variable.VARIABLE(backendinfo = BackendExtension.BACKEND_INFO(varKind = BackendExtension.DISCRETE())) then not isFixed(var_ptr) or hasPre(var_ptr);
case Variable.VARIABLE(backendinfo = BackendExtension.BACKEND_INFO(varKind = BackendExtension.DISCRETE_STATE())) then not isFixed(var_ptr) or hasPre(var_ptr);
case Variable.VARIABLE(backendinfo = BackendExtension.BACKEND_INFO(varKind = BackendExtension.PARAMETER())) then not isFixed(var_ptr);
case Variable.VARIABLE(backendinfo = BackendExtension.BACKEND_INFO(varKind = BackendExtension.PREVIOUS())) then true;
else false;
end match;
end isFixable;
Expand Down Expand Up @@ -574,8 +622,6 @@ public
state_var := match Pointer.access(der_var)
case Variable.VARIABLE(backendinfo = BackendExtension.BACKEND_INFO(varKind = BackendExtension.STATE_DER(state = state_var)))
then state_var;
case Variable.VARIABLE(backendinfo = BackendExtension.BACKEND_INFO(varKind = BackendExtension.PREVIOUS(state = state_var)))
then state_var;
else algorithm
Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed for " + pointerToString(der_var) + " because of wrong variable kind."});
then fail();
Expand All @@ -596,10 +642,6 @@ public
algorithm
stateVar := Pointer.access(state);
then stateVar.name;
case Variable.VARIABLE(backendinfo = BackendExtension.BACKEND_INFO(varKind = BackendExtension.PREVIOUS(state = state)))
algorithm
stateVar := Pointer.access(state);
then stateVar.name;
else algorithm
Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed for " + ComponentRef.toString(cref) + " because of wrong variable kind."});
then fail();
Expand Down Expand Up @@ -647,43 +689,6 @@ public
end match;
end getDerCref;

function getDiscreteStateVar
input Pointer<Variable> pre_var;
output Pointer<Variable> state_var;
algorithm
state_var := match Pointer.access(pre_var)
case Variable.VARIABLE(backendinfo = BackendExtension.BACKEND_INFO(varKind = BackendExtension.PREVIOUS(state = state_var)))
then state_var;
else algorithm
Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed for " + pointerToString(pre_var) + " because of wrong variable kind."});
then fail();
end match;
end getDiscreteStateVar;

function getDiscreteStateCref
"Returns the discrete state variable component reference from a previous reference.
Only works after the discrete state has been detected by the DetectStates module and fails for non-previous crefs!"
input output ComponentRef cref;
algorithm
cref := match cref
local
Pointer<Variable> previous, state;
Variable stateVar;
case ComponentRef.CREF(node = InstNode.VAR_NODE(varPointer = previous)) then match Pointer.access(previous)
case Variable.VARIABLE(backendinfo = BackendExtension.BACKEND_INFO(varKind = BackendExtension.PREVIOUS(state = state)))
algorithm
stateVar := Pointer.access(state);
then stateVar.name;
else algorithm
Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed for " + ComponentRef.toString(cref) + " because of wrong variable kind."});
then fail();
end match;
else algorithm
Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed for " + ComponentRef.toString(cref) + " because of wrong InstNode type."});
then fail();
end match;
end getDiscreteStateCref;

function getRecordChildren
"returns all children of the variable if its a record, otherwise returns empty list"
input Pointer<Variable> var;
Expand Down Expand Up @@ -751,35 +756,34 @@ public
function makeDiscreteStateVar
"Updates a discrete variable pointer to be a discrete state, requires the pointer to its left limit (pre) variable."
input Pointer<Variable> varPointer;
input Pointer<Variable> previous;
protected
Variable var;
Variable var = Pointer.access(varPointer);
algorithm
var := Pointer.access(varPointer);
var.backendinfo := BackendExtension.BackendInfo.setVarKind(var.backendinfo, BackendExtension.DISCRETE_STATE(previous, false));
var.backendinfo := BackendExtension.BackendInfo.setVarKind(var.backendinfo, BackendExtension.DISCRETE_STATE(false));
Pointer.update(varPointer, var);
end makeDiscreteStateVar;

function makePreVar
"Creates a previous variable pointer from the discrete variable cref.
"Creates a previous variable pointer from the variable cref.
e.g. isOpen -> $PRE.isOpen"
input ComponentRef cref "old component reference";
output ComponentRef pre_cref "new component reference";
output Pointer<Variable> var_ptr "pointer to new variable";
output Pointer<Variable> pre_ptr "pointer to new variable";
algorithm
() := match ComponentRef.node(cref)
local
InstNode qual;
Pointer<Variable> disc;
Variable var;
Pointer<Variable> var_ptr;
Variable pre;
case qual as InstNode.VAR_NODE()
algorithm
disc := BVariable.getVarPointer(cref);
var_ptr := BVariable.getVarPointer(cref);
qual.name := PREVIOUS_STR;
pre_cref := ComponentRef.append(cref, ComponentRef.fromNode(qual, ComponentRef.scalarType(cref)));
var := fromCref(pre_cref, Variable.attributes(Pointer.access(disc)));
var.backendinfo := BackendExtension.BackendInfo.setVarKind(var.backendinfo, BackendExtension.PREVIOUS(disc));
(var_ptr, pre_cref) := makeVarPtrCyclic(var, pre_cref);
pre := fromCref(pre_cref, Variable.attributes(Pointer.access(var_ptr)));
pre.backendinfo := BackendExtension.BackendInfo.setVarKind(pre.backendinfo, BackendExtension.PREVIOUS());
(pre_ptr, pre_cref) := makeVarPtrCyclic(pre, pre_cref);
connectPrePostVar(var_ptr, pre_ptr);
then ();

else algorithm
Expand All @@ -788,30 +792,6 @@ public
end match;
end makePreVar;

function getPreCref
"Returns the previous variable component reference from a discrete componet reference.
Only works after the discrete state has been detected by the DetectStates module and fails for non-discrete-state crefs!"
input output ComponentRef cref;
algorithm
cref := match cref
local
Pointer<Variable> disc, previous;
Variable preVar;
case ComponentRef.CREF(node = InstNode.VAR_NODE(varPointer = disc)) then match Pointer.access(disc)
case Variable.VARIABLE(backendinfo = BackendExtension.BACKEND_INFO(varKind = BackendExtension.DISCRETE_STATE(previous = previous)))
algorithm
preVar := Pointer.access(previous);
then preVar.name;
else algorithm
Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed for " + ComponentRef.toString(cref) + " because of wrong variable kind."});
then fail();
end match;
else algorithm
Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed for " + ComponentRef.toString(cref) + " because of wrong InstNode type."});
then fail();
end match;
end getPreCref;

function makeSeedVar
"Creates a seed variable pointer from a cref. Used in NBJacobian and NBHessian
to represent generic gradient equations.
Expand Down Expand Up @@ -1696,11 +1676,13 @@ public
VariablePointers derivatives "State derivatives (der(x) -> $DER.x)";
VariablePointers algebraics "Algebraic variables";
VariablePointers discretes "Discrete variables";
VariablePointers previous "Previous discrete variables (pre(d) -> $PRE.d)";
VariablePointers discrete_states "Discrete state variables";
VariablePointers previous "Previous variables (pre(d) -> $PRE.d)";
// clocked

/* subset of knowns */
VariablePointers states "States";
VariablePointers top_level_inputs "Top level inputs";
VariablePointers parameters "Parameters";
VariablePointers constants "Constants";
VariablePointers records "Records";
Expand Down Expand Up @@ -1828,7 +1810,9 @@ public
VariablePointers.toString(varData.derivatives, "Derivative", false) +
VariablePointers.toString(varData.algebraics, "Algebraic", false) +
VariablePointers.toString(varData.discretes, "Discrete", false) +
VariablePointers.toString(varData.discrete_states, "Discrete States", false) +
VariablePointers.toString(varData.previous, "Previous", false) +
VariablePointers.toString(varData.top_level_inputs, "Top Level Inputs", false) +
VariablePointers.toString(varData.parameters, "Parameter", false) +
VariablePointers.toString(varData.constants, "Constant", false) +
VariablePointers.toString(varData.records, "Record", false) +
Expand Down

0 comments on commit 69d9151

Please sign in to comment.