Skip to content

Commit

Permalink
[NB] update start equations (#11392)
Browse files Browse the repository at this point in the history
- fixes ticket #9986
 - creates start equations from start values and does not create $START variables if they are not constant
 - fixes scalarization ordering such that bindings and start values are correctly applied
  • Loading branch information
kabdelhak committed Oct 18, 2023
1 parent dc60bd1 commit c1d324b
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 14 deletions.
24 changes: 17 additions & 7 deletions OMCompiler/Compiler/NBackEnd/Classes/NBVariable.mo
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public
//NF Imports
import Attributes = NFAttributes;
import BackendExtension = NFBackendExtension;
import NFBackendExtension.{BackendInfo, VariableKind};
import NFBackendExtension.{BackendInfo, VariableKind, VariableAttributes};
import NFBinding.Binding;
import ComponentRef = NFComponentRef;
import Dimension = NFDimension;
Expand Down Expand Up @@ -109,7 +109,7 @@ public
protected
String attr;
algorithm
attr := BackendExtension.VariableAttributes.toString(var.backendinfo.attributes);
attr := VariableAttributes.toString(var.backendinfo.attributes);
str := str + VariableKind.toString(var.backendinfo.varKind) + " (" + intString(Variable.size(var)) + ") " + Variable.toString(var) + (if attr == "" then "" else " " + attr);
end toString;

Expand Down Expand Up @@ -442,18 +442,23 @@ public
algorithm
b := match Pointer.access(var_ptr)
local
BackendExtension.VariableAttributes attributes;
VariableAttributes attributes;
case Variable.VARIABLE(backendinfo = BackendExtension.BACKEND_INFO(attributes = attributes))
then BackendExtension.VariableAttributes.getStateSelect(attributes) == stateSelect;
then VariableAttributes.getStateSelect(attributes) == stateSelect;
else algorithm
Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed for " + toString(Pointer.access(var_ptr))});
then fail();
end match;
end isStateSelect;

function getVariableAttributes
input Variable var;
output VariableAttributes variableAttributes = var.backendinfo.attributes;
end getVariableAttributes;

function setVariableAttributes
input output Variable var;
input BackendExtension.VariableAttributes variableAttributes;
input VariableAttributes variableAttributes;
algorithm
var := match var
local
Expand Down Expand Up @@ -1091,7 +1096,7 @@ public
Expression start;

case Variable.VARIABLE(backendinfo = binfo as BackendExtension.BACKEND_INFO()) algorithm
binfo.attributes := BackendExtension.VariableAttributes.setFixed(binfo.attributes, var.ty, b);
binfo.attributes := VariableAttributes.setFixed(binfo.attributes, var.ty, b);
var.backendinfo := binfo;
then var;

Expand All @@ -1117,7 +1122,7 @@ public

case Variable.VARIABLE(backendinfo = binfo as BackendExtension.BACKEND_INFO()) algorithm
start := Binding.getExp(var.binding);
binfo.attributes := BackendExtension.VariableAttributes.setStartAttribute(binfo.attributes, start);
binfo.attributes := VariableAttributes.setStartAttribute(binfo.attributes, start);
var.backendinfo := binfo;
then var;

Expand All @@ -1136,6 +1141,11 @@ public
var_ptr := setFixed(var_ptr, b);
end setBindingAsStartAndFix;

function getStartAttribute
input Pointer<Variable> var_ptr;
output Option<Expression> start = VariableAttributes.getStartAttribute(getVariableAttributes(Pointer.access(var_ptr)));
end getStartAttribute;

function hasNonTrivialAliasBinding
"returns true if the binding does not represent a cref, a negated cref or a constant.
used for alias removal since only those can be stored as actual alias variables"
Expand Down
53 changes: 48 additions & 5 deletions OMCompiler/Compiler/NBackEnd/Modules/1_Main/NBInitialization.mo
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ protected
import Jacobian = NBJacobian;
import Module = NBModule;
import Partitioning = NBPartitioning;
import Replacements = NBReplacements;
import NBSystem;
import NBSystem.System;
import Tearing = NBTearing;
Expand Down Expand Up @@ -166,6 +167,8 @@ public
Pointer<Variable> var_ptr, start_var;
Pointer<Equation> start_eq;
EquationKind kind;
Option<Expression> start_exp_opt;
Expression start_exp;

// if it is an array create for equation
case Variable.VARIABLE() guard BVariable.isFixed(state) and BVariable.isArray(state) algorithm
Expand All @@ -175,10 +178,20 @@ public
// create scalar equation
case Variable.VARIABLE() guard BVariable.isFixed(state) algorithm
name := BVariable.getVarName(state);
(var_ptr, name, start_var, start_name) := createStartVar(state, name, {});
start_exp_opt := BVariable.getStartAttribute(state);
if Util.isSome(start_exp_opt) and Expression.variability(Util.getOption(start_exp_opt)) > NFPrefixes.Variability.STRUCTURAL_PARAMETER then
// use the start attribute itself if it is not constant
SOME(start_exp) := start_exp_opt;
else
// create a start variable if it is constant
(var_ptr, name, start_var, start_name) := createStartVar(state, name, {});
start_exp := Expression.fromCref(start_name);
Pointer.update(ptr_start_vars, start_var :: Pointer.access(ptr_start_vars));
end if;

// make the new start equation
kind := if BVariable.isContinuous(state) then EquationKind.CONTINUOUS else EquationKind.DISCRETE;
start_eq := Equation.makeAssignment(name, Expression.fromCref(start_name), idx, NBEquation.START_STR, Iterator.EMPTY(), EquationAttributes.default(kind, true));
Pointer.update(ptr_start_vars, start_var :: Pointer.access(ptr_start_vars));
start_eq := Equation.makeAssignment(name, start_exp, idx, NBEquation.START_STR, Iterator.EMPTY(), EquationAttributes.default(kind, true));
Pointer.update(ptr_start_eqs, start_eq :: Pointer.access(ptr_start_eqs));
then ();

Expand Down Expand Up @@ -259,6 +272,8 @@ public
input Pointer<list<Pointer<Equation>>> ptr_start_eqs;
input Pointer<Integer> idx;
protected
Option<Expression> start_exp_opt;
Expression start_exp;
Pointer<Variable> var_ptr, start_var;
ComponentRef name, start_name;
list<Dimension> dims;
Expand All @@ -269,8 +284,13 @@ public
list<tuple<ComponentRef, Expression>> frames;
Pointer<Equation> start_eq;
EquationKind kind;
Call array_constructor;
UnorderedMap<ComponentRef, Expression> replacements;
InstNode old_iter;
ComponentRef new_iter;
algorithm
var_ptr := Slice.getT(state);
// make unique iterators for the new for-loop
name := BVariable.getVarName(var_ptr);
dims := Type.arrayDims(ComponentRef.nodeType(name));
(iterators, ranges, subscripts) := Flatten.makeIterators(name, dims);
Expand All @@ -279,13 +299,36 @@ public
subscripts := list(Subscript.mapExp(sub, BackendDAE.lowerIteratorExp) for sub in subscripts);
frames := List.zip(iter_crefs, ranges);
(var_ptr, name, start_var, start_name) := createStartVar(var_ptr, name, subscripts);

start_exp_opt := BVariable.getStartAttribute(var_ptr);
if Util.isSome(start_exp_opt) and Expression.variability(Util.getOption(start_exp_opt)) > NFPrefixes.Variability.STRUCTURAL_PARAMETER then
// use the start attribute itself if it is not constant
// discard start_var/start_name
SOME(start_exp) := start_exp_opt;
// if it is some kind of array repeating structure, extract the repeated element e.g. fill()
start_exp := match start_exp
case Expression.CALL(call = array_constructor as Call.TYPED_ARRAY_CONSTRUCTOR()) algorithm
replacements := UnorderedMap.new<Expression>(ComponentRef.hash, ComponentRef.isEqual);
for tpl in List.zip(array_constructor.iters, frames) loop
((old_iter, _), (new_iter, _)) := tpl;
UnorderedMap.add(ComponentRef.fromNode(old_iter, InstNode.getType(old_iter)), Expression.fromCref(new_iter), replacements);
end for;
then Expression.map(array_constructor.exp, function Replacements.applySimpleExp(replacements = replacements));
else start_exp;
end match;
else
// create a start variable if it is constant
start_exp := Expression.fromCref(start_name);
Pointer.update(ptr_start_vars, start_var :: Pointer.access(ptr_start_vars));
end if;

// make the new start equation
kind := if BVariable.isContinuous(var_ptr) then EquationKind.CONTINUOUS else EquationKind.DISCRETE;
start_eq := Equation.makeAssignment(name, Expression.fromCref(start_name), idx, NBEquation.START_STR, Iterator.fromFrames(frames), EquationAttributes.default(kind, true));
start_eq := Equation.makeAssignment(name, start_exp, idx, NBEquation.START_STR, Iterator.fromFrames(frames), EquationAttributes.default(kind, true));
if not listEmpty(state.indices) then
// empty list indicates full array, slice otherwise
(start_eq, _, _) := Equation.slice(start_eq, state.indices, NONE(), FunctionTreeImpl.EMPTY());
end if;
Pointer.update(ptr_start_vars, start_var :: Pointer.access(ptr_start_vars));
Pointer.update(ptr_start_eqs, start_eq :: Pointer.access(ptr_start_eqs));
end createStartEquationSlice;

Expand Down
6 changes: 4 additions & 2 deletions OMCompiler/Compiler/NFFrontEnd/NFScalarize.mo
Original file line number Diff line number Diff line change
Expand Up @@ -184,21 +184,22 @@ protected
list<BackendInfo> backend_attributes;
algorithm
try
vars := listReverse(vars);
crefs := ComponentRef.scalarizeAll(ComponentRef.stripSubscriptsAll(var.name));
elem_ty := Type.arrayElementType(var.ty);
backend_attributes := BackendInfo.scalarize(var.backendinfo, listLength(crefs));
if Binding.isBound(var.binding) then
binding_iter := ExpressionIterator.fromExp(Binding.getTypedExp(var.binding), true);
bind_var := Binding.variability(var.binding);
bind_src := Binding.source(var.binding);
for cr in crefs loop
for cr in listReverse(crefs) loop
(binding_iter, exp) := ExpressionIterator.next(binding_iter);
binding := Binding.makeFlat(exp, bind_var, bind_src);
binfo :: backend_attributes := backend_attributes;
vars := Variable.VARIABLE(cr, elem_ty, binding, var.visibility, var.attributes, {}, {}, var.comment, var.info, binfo) :: vars;
end for;
else
for cr in crefs loop
for cr in listReverse(crefs) loop
binfo :: backend_attributes := backend_attributes;
vars := Variable.VARIABLE(cr, elem_ty, var.binding, var.visibility, var.attributes, {}, {}, var.comment, var.info, binfo) :: vars;
end for;
Expand All @@ -211,6 +212,7 @@ algorithm
else
Error.assertion(false, getInstanceName() + " failed for: " + Variable.toString(var), sourceInfo());
end try;
vars := listReverse(vars);
end scalarizeBackendVariable;

function scalarizeComplexVariable
Expand Down

0 comments on commit c1d324b

Please sign in to comment.