Skip to content

Commit

Permalink
[NB] refine adjacency matrix (#12298)
Browse files Browse the repository at this point in the history
* [NB] start minimal tearing

* [NB] update strong component after tearing method

* [NB] new tearing utility

 - start Adjacency.Matrix.refine which refines solvabilty information
 - apply tearingFinalize on implicitely solved equations
  • Loading branch information
kabdelhak committed Apr 23, 2024
1 parent a842deb commit 188b3c3
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 32 deletions.
75 changes: 44 additions & 31 deletions OMCompiler/Compiler/NBackEnd/Modules/3_Post/NBTearing.mo
Expand Up @@ -48,6 +48,7 @@ protected
import Tearing = NBTearing;

// NF imports
import ComponentRef = NFComponentRef;
import NFFlatten.FunctionTree;
import Variable = NFVariable;

Expand Down Expand Up @@ -121,30 +122,29 @@ public
bdae := match (systemType, bdae)
local
list<System.System> systems;
VariablePointers variables;
Pointer<Integer> eq_index;

case (NBSystem.SystemType.ODE, BackendDAE.MAIN(ode = systems, funcTree = funcTree, varData = BVariable.VAR_DATA_SIM(variables = variables), eqData = BEquation.EQ_DATA_SIM(uniqueIndex = eq_index)))
case (NBSystem.SystemType.ODE, BackendDAE.MAIN(ode = systems, funcTree = funcTree, eqData = BEquation.EQ_DATA_SIM(uniqueIndex = eq_index)))
algorithm
(systems, funcTree) := tearingTraverser(systems, funcs, funcTree, variables, eq_index, systemType);
(systems, funcTree) := tearingTraverser(systems, funcs, funcTree, eq_index, systemType);
bdae.ode := systems;
bdae.funcTree := funcTree;
then bdae;

case (NBSystem.SystemType.INI, BackendDAE.MAIN(init = systems, funcTree = funcTree, varData = BVariable.VAR_DATA_SIM(variables = variables), eqData = BEquation.EQ_DATA_SIM(uniqueIndex = eq_index)))
case (NBSystem.SystemType.INI, BackendDAE.MAIN(init = systems, funcTree = funcTree, eqData = BEquation.EQ_DATA_SIM(uniqueIndex = eq_index)))
algorithm
(systems, funcTree) := tearingTraverser(systems, funcs, funcTree, variables, eq_index, systemType);
(systems, funcTree) := tearingTraverser(systems, funcs, funcTree, eq_index, systemType);
bdae.init := systems;
if Util.isSome(bdae.init_0) then
(systems, funcTree) := tearingTraverser(Util.getOption(bdae.init_0), funcs, funcTree, variables, eq_index, systemType);
(systems, funcTree) := tearingTraverser(Util.getOption(bdae.init_0), funcs, funcTree, eq_index, systemType);
bdae.init_0 := SOME(systems);
end if;
bdae.funcTree := funcTree;
then bdae;

case (NBSystem.SystemType.DAE, BackendDAE.MAIN(dae = SOME(systems), funcTree = funcTree, varData = BVariable.VAR_DATA_SIM(variables = variables), eqData = BEquation.EQ_DATA_SIM(uniqueIndex = eq_index)))
case (NBSystem.SystemType.DAE, BackendDAE.MAIN(dae = SOME(systems), funcTree = funcTree, eqData = BEquation.EQ_DATA_SIM(uniqueIndex = eq_index)))
algorithm
(systems, funcTree) := tearingTraverser(systems, funcs, funcTree, variables, eq_index, systemType);
(systems, funcTree) := tearingTraverser(systems, funcs, funcTree, eq_index, systemType);
bdae.dae := SOME(systems);
bdae.funcTree := funcTree;
then bdae;
Expand All @@ -161,26 +161,36 @@ public
protected
// dummy adjacency matrix, dont need it for tearingNone()
Adjacency.Matrix dummy = Adjacency.EMPTY(NBAdjacency.MatrixStrictness.FULL);
list<Module.tearingInterface> funcs = {tearingNone, tearingFinalize};
StrongComponent new_comp;
algorithm
(comp, dummy, funcTree, index) := match comp
// create implicit equations
case StrongComponent.SINGLE_COMPONENT()
then tearingNone(StrongComponent.ALGEBRAIC_LOOP(
idx = index,
strict = singleImplicit(comp.var, comp.eqn),
casual = NONE(),
linear = false,
mixed = false,
status = NBSolve.Status.IMPLICIT), dummy, funcTree, index, VariablePointers.empty(), Pointer.create(0), systemType);

case StrongComponent.MULTI_COMPONENT()
then tearingNone(StrongComponent.ALGEBRAIC_LOOP(
idx = index,
strict = singleImplicit(List.first(comp.vars), comp.eqn), // this is wrong! need to take all vars
casual = NONE(),
linear = false,
mixed = false,
status = NBSolve.Status.IMPLICIT), dummy, funcTree, index, VariablePointers.empty(), Pointer.create(0), systemType);
case StrongComponent.SINGLE_COMPONENT() algorithm
new_comp := StrongComponent.ALGEBRAIC_LOOP(
idx = index,
strict = singleImplicit(comp.var, comp.eqn),
casual = NONE(),
linear = false,
mixed = false,
status = NBSolve.Status.IMPLICIT);
for func in funcs loop
(new_comp, dummy, funcTree, index) := func(new_comp, dummy, funcTree, index, VariablePointers.empty(), EquationPointers.empty(), Pointer.create(0), systemType);
end for;
then (new_comp, dummy, funcTree, index);

case StrongComponent.MULTI_COMPONENT() algorithm
new_comp := StrongComponent.ALGEBRAIC_LOOP(
idx = index,
strict = singleImplicit(List.first(comp.vars), comp.eqn), // this is wrong! need to take all vars
casual = NONE(),
linear = false,
mixed = false,
status = NBSolve.Status.IMPLICIT);
for func in funcs loop
(new_comp, dummy, funcTree, index) := func(new_comp, dummy, funcTree, index, VariablePointers.empty(), EquationPointers.empty(), Pointer.create(0), systemType);
end for;
then (new_comp, dummy, funcTree, index);

// do nothing otherwise
else (comp, dummy, funcTree, index);
Expand Down Expand Up @@ -234,7 +244,6 @@ protected
input list<Module.tearingInterface> funcs;
output list<System.System> new_systems = {};
input output FunctionTree funcTree;
input VariablePointers variables;
input Pointer<Integer> eq_index;
input System.SystemType systemType;
protected
Expand All @@ -251,7 +260,7 @@ protected
// each module has a list of functions that need to be applied
tmp := strongComponents[i];
for func in funcs loop
(tmp, full, funcTree, idx) := func(tmp, full, funcTree, idx, variables, eq_index, systemType);
(tmp, full, funcTree, idx) := func(tmp, full, funcTree, idx, syst.unknowns, syst.equations, eq_index, systemType);
end for;
// only update if it changed
if not referenceEq(tmp, strongComponents[i]) then
Expand Down Expand Up @@ -288,7 +297,6 @@ protected
comps = listArray(residual_comps),
funcTree = funcTree,
name = System.System.systemTypeString(systemType) + tag + intString(index));

strict.jac := jacobian;
comp.strict := strict;
if Flags.isSet(Flags.TEARING_DUMP) then
Expand Down Expand Up @@ -346,8 +354,9 @@ protected
Adjacency.Matrix adj;
Matching matching;
list<StrongComponent> inner_comps, residual_comps;
UnorderedMap<ComponentRef, Integer> v, e;
algorithm
//print("######## minimal ########\n");
print("######## minimal ########\n");
(comp, index) := match comp
case StrongComponent.ALGEBRAIC_LOOP(strict = strict) algorithm

Expand All @@ -356,8 +365,12 @@ protected
(cont_vars, disc_vars) := List.splitOnTrue(vars_lst, BVariable.isContinuous);
(cont_eqns, disc_eqns) := List.splitOnTrue(eqns_lst, Equation.isContinuous);

//print(List.toString(disc_vars, function BVariable.pointerToString()) + "\n");
//print(List.toString(disc_eqns, function Equation.pointerToString(str = "")) + "\n");
print(List.toString(disc_vars, function BVariable.pointerToString(), "", "", "\n", "") + "\n");
print(List.toString(disc_eqns, function Equation.pointerToString(str = ""), "", "", "\n", "") + "\n");
v := UnorderedMap.subSet(variables.map, list(BVariable.getVarName(var) for var in disc_vars));
e := UnorderedMap.subSet(equations.map, list(Equation.getEqnName(eqn) for eqn in disc_eqns));
(full, funcTree) := Adjacency.Matrix.refine(full, funcTree, v, e, variables, equations);


/*
Expand Down
1 change: 1 addition & 0 deletions OMCompiler/Compiler/NBackEnd/Modules/NBModule.mo
Expand Up @@ -291,6 +291,7 @@ public
input output FunctionTree funcTree "Function call bodies";
input output Integer index "current unique loop index";
input VariablePointers variables "all variables";
input EquationPointers equations "all equations";
input Pointer<Integer> eq_index "equation index";
input System.SystemType systemType = NBSystem.SystemType.ODE "system type";
end tearingInterface;
Expand Down
46 changes: 45 additions & 1 deletion OMCompiler/Compiler/NBackEnd/Util/NBAdjacency.mo
Expand Up @@ -44,13 +44,15 @@ protected
import Dimension = NFDimension;
import Expression = NFExpression;
import FunctionTree = NFFlatten.FunctionTree;
import SimplifyExp = NFSimplifyExp;
import Subscript = NFSubscript;
import Type = NFType;
import Operator = NFOperator;
import Variable = NFVariable;

// NB imports
import Differentiate = NBDifferentiate;
import NBDifferentiate.{DifferentiationArguments, DifferentiationType};
import BEquation = NBEquation;
import NBEquation.{Equation, EquationAttributes, EquationPointers, Iterator, IfEquationBody, WhenEquationBody, WhenStatement};
import BVariable = NBVariable;
Expand Down Expand Up @@ -686,8 +688,9 @@ public
end for;
end if;
then full;

else algorithm
Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " expected types final, got type " + strictnessString(getStrictness(full)) + "."});
Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " expected type full, got type " + strictnessString(getStrictness(full)) + "."});
then fail();
end match;

Expand All @@ -696,6 +699,47 @@ public
end if;
end expandFull;

function refine
"refines the solvability kind using differentiation
Note: only updates the solvabilites of the variables and equations from the maps"
input output Matrix full;
input output FunctionTree funcTree;
input UnorderedMap<ComponentRef, Integer> v "variables to refine";
input UnorderedMap<ComponentRef, Integer> e "equations to refine";
input VariablePointers vars "all variables";
input EquationPointers eqns "all equations";
algorithm
full := match full
local
ComponentRef eqn, var;
Integer eqn_idx, var_idx;
DifferentiationArguments diffArgs;
Pointer<Equation> eqn_ptr;
Expression exp;

case FULL() algorithm
for v_tpl in UnorderedMap.toList(v) loop
(var, var_idx) := v_tpl;
diffArgs := DifferentiationArguments.default(NBDifferentiate.DifferentiationType.SIMPLE, funcTree);
diffArgs.diffCref := var;
for e_tpl in UnorderedMap.toList(e) loop
(eqn, eqn_idx) := e_tpl;
eqn_ptr := EquationPointers.getEqnAt(eqns, eqn_idx);
// get the residual expression, differentiate and simplify it
exp := Equation.getResidualExp(Pointer.access(eqn_ptr));
(exp, diffArgs) := Differentiate.differentiateExpressionDump(exp, diffArgs, getInstanceName());
exp := SimplifyExp.simplifyDump(exp, true, getInstanceName());
print("the partial derivative by " + ComponentRef.toString(var) + " of equation\n" + Equation.pointerToString(eqn_ptr) + "\nis: " + Expression.toString(exp) + "\n");
end for;
end for;
then full;

else algorithm
Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " expected type full, got type " + strictnessString(getStrictness(full)) + "."});
then fail();
end match;
end refine;

function compress
"use after equations have been removed"
input output Matrix adj;
Expand Down

0 comments on commit 188b3c3

Please sign in to comment.