Skip to content

Commit

Permalink
[NB] add partial derivatives of functions (#12203)
Browse files Browse the repository at this point in the history
* [NB] add partial derivatives of functions

 - differentiate partial derivative functions while lowering

* [NB] add test for partial derivative

* [NB] add missing else cases to partial function diff
  • Loading branch information
kabdelhak committed Apr 9, 2024
1 parent 6016ef7 commit f391abb
Show file tree
Hide file tree
Showing 6 changed files with 262 additions and 18 deletions.
29 changes: 25 additions & 4 deletions OMCompiler/Compiler/NBackEnd/Classes/NBackendDAE.mo
Expand Up @@ -47,8 +47,11 @@ public
import NBStrongComponent.CountCollector;
import NBSystem;
import NBSystem.System;

protected
// Old Frontend imports
import Absyn.Path;
import AbsynUtil;

// New Frontend imports
import Algorithm = NFAlgorithm;
import BackendExtension = NFBackendExtension;
Expand All @@ -62,6 +65,7 @@ protected
import Expression = NFExpression;
import FEquation = NFEquation;
import FlatModel = NFFlatModel;
import NFFunction.Function;
import InstNode = NFInstNode.InstNode;
import Prefixes = NFPrefixes;
import Statement = NFStatement;
Expand All @@ -74,8 +78,9 @@ protected
import BackendDAE = NBackendDAE;
import Bindings = NBBindings;
import Causalize = NBCausalize;
import DetectStates = NBDetectStates;
import DAEMode = NBDAEMode;
import DetectStates = NBDetectStates;
import Differentiate = NBDifferentiate;
import FunctionAlias = NBFunctionAlias;
import Initialization = NBInitialization;
import Inline = NBInline;
Expand Down Expand Up @@ -210,11 +215,11 @@ public
VarData variableData;
EqData equationData;
Events.EventInfo eventInfo = Events.EventInfo.empty();
UnorderedMap<Path, Function> functions;
algorithm
// expand records to its children. Put behind flag?
variableData := lowerVariableData(flatModel.variables);
(equationData, variableData) := lowerEquationData(flatModel.equations, flatModel.algorithms, flatModel.initialEquations, flatModel.initialAlgorithms, variableData);
bdae := MAIN({}, {}, {}, {}, {}, NONE(), NONE(), variableData, equationData, eventInfo, funcTree);
bdae := MAIN({}, {}, {}, {}, {}, NONE(), NONE(), variableData, equationData, eventInfo, lowerFunctions(funcTree));
end lower;

function main
Expand Down Expand Up @@ -1313,6 +1318,22 @@ public
end match;
end lowerIteratorExp;
function lowerFunctions
input output FunctionTree funcTree;
protected
// ToDo: replace all function trees with this UnorderedMap
UnorderedMap<Path, Function> functions = UnorderedMap.new<Function>(AbsynUtil.pathHash, AbsynUtil.pathEqual);
protected
Path path;
Function fn;
algorithm
for tpl in FunctionTree.toList(funcTree) loop
(path, fn) := tpl;
(fn, funcTree) := Differentiate.resolvePartialDerivatives(fn, funcTree);
UnorderedMap.add(path, fn, functions);
end for;
end lowerFunctions;
function backenddaeinfo
input BackendDAE bdae;
algorithm
Expand Down
171 changes: 161 additions & 10 deletions OMCompiler/Compiler/NBackEnd/Util/NBDifferentiate.mo
Expand Up @@ -44,9 +44,11 @@ public

// NF imports
import Algorithm = NFAlgorithm;
import Binding = NFBinding;
import BuiltinFuncs = NFBuiltinFuncs;
import Call = NFCall;
import Class = NFClass;
import Component = NFComponent;
import ComponentRef = NFComponentRef;
import Expression = NFExpression;
import NFInstNode.{InstNode, CachedData};
Expand All @@ -66,6 +68,7 @@ public
import NBEquation.{Equation, EquationAttributes, EquationPointer, EquationPointers, IfEquationBody, WhenEquationBody, WhenStatement};
import NBVariable.{VariablePointer};
import BVariable = NBVariable;
import Replacements = NBReplacements;
import StrongComponent = NBStrongComponent;

// Util imports
Expand Down Expand Up @@ -555,7 +558,6 @@ public
Expression res;
UnorderedMap<ComponentRef,ComponentRef> jacobianHT;


// -------------------------------------
// EMPTY and WILD crefs do nothing
// -------------------------------------
Expand Down Expand Up @@ -828,7 +830,7 @@ public

end differentiateCall;

protected function differentiateBuiltinCall
function differentiateBuiltinCall
"This function differentiates built-in call expressions with respect to a given variable.
Also creates and multiplies inner derivatives."
input String name;
Expand Down Expand Up @@ -1238,23 +1240,31 @@ public
Function dummy_func;
CachedData cachedData;
String der_func_name;
list<InstNode> local_outputs;

case der_func as Function.FUNCTION(node = node as InstNode.CLASS_NODE(cls = cls)) algorithm
new_cls := match Pointer.access(cls)
case new_cls as Class.INSTANCED_CLASS(sections = sections as Sections.SECTIONS()) algorithm

// differentiate interface arguments
der_func.inputs := differentiateFunctionInterfaceNodes(der_func.inputs, interface_map, diff_map, true);
der_func.locals := differentiateFunctionInterfaceNodes(der_func.locals, interface_map, diff_map, true);
der_func.locals := listAppend(der_func.locals, list(InstNode.setComponentDirection(NFPrefixes.Direction.NONE, node) for node in der_func.outputs));
der_func.outputs := differentiateFunctionInterfaceNodes(der_func.outputs, interface_map, diff_map, false);
// prepare outputs that become locals
local_outputs := list(InstNode.setComponentDirection(NFPrefixes.Direction.NONE, node) for node in der_func.outputs);
local_outputs := list(InstNode.protect(node) for node in local_outputs);

// prepare differentiation arguments
funcDiffArgs := DifferentiationArguments.default();
funcDiffArgs.diffType := DifferentiationType.FUNCTION;
funcDiffArgs.funcTree := diffArguments.funcTree;
createInterfaceDerivatives(der_func.inputs, interface_map, diff_map);
createInterfaceDerivatives(der_func.locals, interface_map, diff_map);
createInterfaceDerivatives(der_func.outputs, interface_map, diff_map);
funcDiffArgs.jacobianHT := SOME(diff_map);

// differentiate interface arguments
der_func.inputs := differentiateFunctionInterfaceNodes(der_func.inputs, interface_map, diff_map, funcDiffArgs, true);
der_func.locals := differentiateFunctionInterfaceNodes(der_func.locals, interface_map, diff_map, funcDiffArgs, true);
der_func.outputs := differentiateFunctionInterfaceNodes(der_func.outputs, interface_map, diff_map, funcDiffArgs, false);

der_func.locals := listAppend(der_func.locals, local_outputs);

// create "fake" function with correct interface to have the interface
// in the case of recursive differentiation (e.g. function calls itself)
dummy_func := func;
Expand Down Expand Up @@ -1328,22 +1338,149 @@ public
input output list<InstNode> interface_nodes;
input UnorderedMap<String, Boolean> interface_map;
input UnorderedMap<ComponentRef, ComponentRef> diff_map;
input output DifferentiationArguments diffArgs;
input Boolean keepOld;
protected
list<InstNode> new_nodes;
ComponentRef cref, diff_cref;
InstNode comp_node;
Component comp;
Binding binding;
algorithm
new_nodes := if keepOld then listReverse(interface_nodes) else {};
interface_nodes := list(node for node guard(not UnorderedMap.contains(InstNode.name(node), interface_map)) in interface_nodes);
for node in interface_nodes loop
cref := ComponentRef.fromNode(node, InstNode.getType(node));
diff_cref := BVariable.makeFDerVar(cref);
UnorderedMap.add(cref, diff_cref, diff_map);
diff_cref := UnorderedMap.getSafe(cref, diff_map, sourceInfo());
diff_cref := match diff_cref
case ComponentRef.CREF(node = comp_node as InstNode.COMPONENT_NODE()) algorithm
// differentiate bindings
comp := Pointer.access(comp_node.component);
comp := match comp
case comp as Component.COMPONENT() algorithm
(binding, diffArgs) := differentiateBinding(comp.binding, diffArgs);
comp.binding := binding;
then comp;
else comp;
end match;
comp_node.component := Pointer.create(comp);
diff_cref.node := comp_node;
then diff_cref;
else diff_cref;
end match;
new_nodes := ComponentRef.node(diff_cref) :: new_nodes;
end for;
interface_nodes := listReverse(new_nodes);
end differentiateFunctionInterfaceNodes;

function createInterfaceDerivatives
input list<InstNode> interface_nodes;
input UnorderedMap<String, Boolean> interface_map;
input UnorderedMap<ComponentRef, ComponentRef> diff_map;
protected
ComponentRef cref, diff_cref;
list<InstNode> n;
algorithm
n := list(node for node guard(not UnorderedMap.contains(InstNode.name(node), interface_map)) in interface_nodes);
for node in n loop
cref := ComponentRef.fromNode(node, InstNode.getType(node));
diff_cref := BVariable.makeFDerVar(cref);
UnorderedMap.add(cref, diff_cref, diff_map);
end for;
end createInterfaceDerivatives;

function resolvePartialDerivatives
input output Function func;
input output FunctionTree funcTree;
protected
Function der_func;
InstNode node;
Pointer<Class> cls, tmp_cls;
Class new_cls, wrap_cls;
Sections sections;
UnorderedMap<ComponentRef, ComponentRef> diff_map = UnorderedMap.new<ComponentRef>(ComponentRef.hash, ComponentRef.isEqual);
UnorderedMap<String, Boolean> interface_map;
DifferentiationArguments diffArgs = DifferentiationArguments.default();
list<Algorithm> algorithms;
CachedData cachedData;
InstNode diffVar;
ComponentRef diffCref;
list<InstNode> local_outputs;
Boolean changed = false;
algorithm
func := match func
case der_func as Function.FUNCTION(node = InstNode.CLASS_NODE(cls = cls)) algorithm
wrap_cls := Pointer.access(cls);
new_cls := match wrap_cls
case wrap_cls as Class.TYPED_DERIVED(baseClass = node as InstNode.CLASS_NODE(cls = tmp_cls)) algorithm
new_cls := match Pointer.access(tmp_cls)
case new_cls as Class.INSTANCED_CLASS(sections = sections as Sections.SECTIONS(algorithms = algorithms)) algorithm
// prepare differentiation arguments
diffArgs.diffType := DifferentiationType.FUNCTION;
diffArgs.funcTree := funcTree;

interface_map := UnorderedMap.fromLists(list(InstNode.name(var) for var in der_func.inputs), List.fill(false, listLength(der_func.inputs)), stringHashDjb2, stringEqual);

// add all differentiated inputs to the interface map
for var in List.getAtIndexLst(der_func.inputs, der_func.derivedInputs) loop
UnorderedMap.remove(InstNode.name(var), interface_map);

// prepare outputs that become locals
local_outputs := list(InstNode.setComponentDirection(NFPrefixes.Direction.NONE, node) for node in der_func.outputs);
local_outputs := list(InstNode.protect(node) for node in local_outputs);

// differentiate interface arguments
createInterfaceDerivatives({var}, interface_map, diff_map);
createInterfaceDerivatives(der_func.locals, interface_map, diff_map);
createInterfaceDerivatives(der_func.outputs, interface_map, diff_map);
diffArgs.jacobianHT := SOME(diff_map);

der_func.locals := differentiateFunctionInterfaceNodes(der_func.locals, interface_map, diff_map, diffArgs, true);
der_func.outputs := differentiateFunctionInterfaceNodes(der_func.outputs, interface_map, diff_map, diffArgs, false);

diffCref := UnorderedMap.getSafe(ComponentRef.fromNode(var, InstNode.getType(var)), diff_map, sourceInfo());
der_func.locals := listAppend(der_func.locals, local_outputs);

// differentiate function statements
(algorithms, diffArgs) := List.mapFold(algorithms, differentiateAlgorithm, diffArgs);
algorithms := Algorithm.mapExpList(algorithms, function Replacements.single(old = Expression.fromCref(diffCref), new = Expression.makeOne(ComponentRef.getSubscriptedType(diffCref))));

UnorderedMap.add(InstNode.name(var), false, interface_map);
end for;

// add them to new node
sections.algorithms := algorithms;
new_cls.sections := sections;
new_cls.ty := wrap_cls.ty;
new_cls.restriction := wrap_cls.restriction;
node.cls := Pointer.create(new_cls);
cachedData := CachedData.FUNCTION({der_func}, true, false);
der_func.node := InstNode.setFuncCache(node, cachedData);
der_func.derivatives := {};
der_func.derivedInputs := {};

changed := true;
then new_cls;

else wrap_cls;
end match;
then new_cls;
else wrap_cls;
end match;

if changed then
if Flags.isSet(Flags.DEBUG_DIFFERENTIATION) then
print("\n[BEFORE] " + Function.toFlatString(func) + "\n");
print("\n[AFTER ] " + Function.toFlatString(der_func) + "\n\n");
end if;
funcTree := FunctionTreeImpl.add(funcTree, der_func.path, der_func, FunctionTreeImpl.addConflictReplace);
end if;
then der_func;

else func;
end match;
end resolvePartialDerivatives;

function differentiateAlgorithm
input output Algorithm alg;
input output DifferentiationArguments diffArguments;
Expand Down Expand Up @@ -1743,6 +1880,20 @@ public
end match;
end differentiateEquationAttributes;

function differentiateBinding
input output Binding binding;
input output DifferentiationArguments diffArgs;
protected
Option<Expression> opt_exp;
Expression exp;
algorithm
opt_exp := Binding.getExpOpt(binding);
if Util.isSome(opt_exp) then
(exp, diffArgs) := differentiateExpression(Util.getOption(opt_exp), diffArgs);
binding := Binding.setExp(exp, binding);
end if;
end differentiateBinding;

protected
function minusOne
input output Expression exp;
Expand Down
20 changes: 20 additions & 0 deletions OMCompiler/Compiler/NFFrontEnd/NFInstNode.mo
Expand Up @@ -1717,6 +1717,26 @@ uniontype InstNode
end match;
end protectComponent;

function protect
input output InstNode node;
algorithm
() := match node
case COMPONENT_NODE(visibility = Visibility.PUBLIC)
algorithm
node.visibility := Visibility.PROTECTED;
then
();

case CLASS_NODE(visibility = Visibility.PUBLIC)
algorithm
node.visibility := Visibility.PROTECTED;
then
();

else ();
end match;
end protect;

function isEncapsulated
input InstNode node;
output Boolean enc;
Expand Down
6 changes: 2 additions & 4 deletions OMCompiler/Compiler/Util/List.mo
Expand Up @@ -733,14 +733,12 @@ public function getAtIndexLst<T>
input list<T> lst;
input list<Integer> positions;
input Boolean zeroBased = false;
output list<T> olst = {};
output list<T> olst;
protected
array<T> arr = listArray(lst);
Integer shift = if zeroBased then 1 else 0;
algorithm
for pos in listReverse(positions) loop
olst := arr[pos+shift] :: olst;
end for;
olst := list(arr[pos+shift] for pos in positions);
end getAtIndexLst;

public function firstN<T>
Expand Down
1 change: 1 addition & 0 deletions testsuite/simulation/modelica/NBackend/functions/Makefile
Expand Up @@ -4,6 +4,7 @@ TESTFILES = \
builtin_functions.mos\
function_annotation_der.mos\
function_diff.mos\
function_partial_der.mos\

# test that currently fail. Move up when fixed.
# Run make failingtest
Expand Down

0 comments on commit f391abb

Please sign in to comment.