diff --git a/OMCompiler/Compiler/NBackEnd/Classes/NBackendDAE.mo b/OMCompiler/Compiler/NBackEnd/Classes/NBackendDAE.mo index 3361cdc0339..67438570590 100644 --- a/OMCompiler/Compiler/NBackEnd/Classes/NBackendDAE.mo +++ b/OMCompiler/Compiler/NBackEnd/Classes/NBackendDAE.mo @@ -47,8 +47,11 @@ public import NBStrongComponent.CountCollector; import NBSystem; import NBSystem.System; - protected + // Old Frontend imports + import Absyn.Path; + import AbsynUtil; + // New Frontend imports import Algorithm = NFAlgorithm; import BackendExtension = NFBackendExtension; @@ -62,6 +65,7 @@ protected import Expression = NFExpression; import FEquation = NFEquation; import FlatModel = NFFlatModel; + import NFFunction.Function; import InstNode = NFInstNode.InstNode; import Prefixes = NFPrefixes; import Statement = NFStatement; @@ -74,8 +78,9 @@ protected import BackendDAE = NBackendDAE; import Bindings = NBBindings; import Causalize = NBCausalize; - import DetectStates = NBDetectStates; import DAEMode = NBDAEMode; + import DetectStates = NBDetectStates; + import Differentiate = NBDifferentiate; import FunctionAlias = NBFunctionAlias; import Initialization = NBInitialization; import Inline = NBInline; @@ -210,11 +215,11 @@ public VarData variableData; EqData equationData; Events.EventInfo eventInfo = Events.EventInfo.empty(); + UnorderedMap functions; algorithm - // expand records to its children. Put behind flag? variableData := lowerVariableData(flatModel.variables); (equationData, variableData) := lowerEquationData(flatModel.equations, flatModel.algorithms, flatModel.initialEquations, flatModel.initialAlgorithms, variableData); - bdae := MAIN({}, {}, {}, {}, {}, NONE(), NONE(), variableData, equationData, eventInfo, funcTree); + bdae := MAIN({}, {}, {}, {}, {}, NONE(), NONE(), variableData, equationData, eventInfo, lowerFunctions(funcTree)); end lower; function main @@ -1313,6 +1318,22 @@ public end match; end lowerIteratorExp; + function lowerFunctions + input output FunctionTree funcTree; + protected + // ToDo: replace all function trees with this UnorderedMap + UnorderedMap functions = UnorderedMap.new(AbsynUtil.pathHash, AbsynUtil.pathEqual); + protected + Path path; + Function fn; + algorithm + for tpl in FunctionTree.toList(funcTree) loop + (path, fn) := tpl; + (fn, funcTree) := Differentiate.resolvePartialDerivatives(fn, funcTree); + UnorderedMap.add(path, fn, functions); + end for; + end lowerFunctions; + function backenddaeinfo input BackendDAE bdae; algorithm diff --git a/OMCompiler/Compiler/NBackEnd/Util/NBDifferentiate.mo b/OMCompiler/Compiler/NBackEnd/Util/NBDifferentiate.mo index 302c51ff32d..0a4172db211 100644 --- a/OMCompiler/Compiler/NBackEnd/Util/NBDifferentiate.mo +++ b/OMCompiler/Compiler/NBackEnd/Util/NBDifferentiate.mo @@ -44,9 +44,11 @@ public // NF imports import Algorithm = NFAlgorithm; + import Binding = NFBinding; import BuiltinFuncs = NFBuiltinFuncs; import Call = NFCall; import Class = NFClass; + import Component = NFComponent; import ComponentRef = NFComponentRef; import Expression = NFExpression; import NFInstNode.{InstNode, CachedData}; @@ -66,6 +68,7 @@ public import NBEquation.{Equation, EquationAttributes, EquationPointer, EquationPointers, IfEquationBody, WhenEquationBody, WhenStatement}; import NBVariable.{VariablePointer}; import BVariable = NBVariable; + import Replacements = NBReplacements; import StrongComponent = NBStrongComponent; // Util imports @@ -555,7 +558,6 @@ public Expression res; UnorderedMap jacobianHT; - // ------------------------------------- // EMPTY and WILD crefs do nothing // ------------------------------------- @@ -828,7 +830,7 @@ public end differentiateCall; - protected function differentiateBuiltinCall + function differentiateBuiltinCall "This function differentiates built-in call expressions with respect to a given variable. Also creates and multiplies inner derivatives." input String name; @@ -1238,23 +1240,31 @@ public Function dummy_func; CachedData cachedData; String der_func_name; + list local_outputs; case der_func as Function.FUNCTION(node = node as InstNode.CLASS_NODE(cls = cls)) algorithm new_cls := match Pointer.access(cls) case new_cls as Class.INSTANCED_CLASS(sections = sections as Sections.SECTIONS()) algorithm - - // differentiate interface arguments - der_func.inputs := differentiateFunctionInterfaceNodes(der_func.inputs, interface_map, diff_map, true); - der_func.locals := differentiateFunctionInterfaceNodes(der_func.locals, interface_map, diff_map, true); - der_func.locals := listAppend(der_func.locals, list(InstNode.setComponentDirection(NFPrefixes.Direction.NONE, node) for node in der_func.outputs)); - der_func.outputs := differentiateFunctionInterfaceNodes(der_func.outputs, interface_map, diff_map, false); + // prepare outputs that become locals + local_outputs := list(InstNode.setComponentDirection(NFPrefixes.Direction.NONE, node) for node in der_func.outputs); + local_outputs := list(InstNode.protect(node) for node in local_outputs); // prepare differentiation arguments funcDiffArgs := DifferentiationArguments.default(); funcDiffArgs.diffType := DifferentiationType.FUNCTION; funcDiffArgs.funcTree := diffArguments.funcTree; + createInterfaceDerivatives(der_func.inputs, interface_map, diff_map); + createInterfaceDerivatives(der_func.locals, interface_map, diff_map); + createInterfaceDerivatives(der_func.outputs, interface_map, diff_map); funcDiffArgs.jacobianHT := SOME(diff_map); + // differentiate interface arguments + der_func.inputs := differentiateFunctionInterfaceNodes(der_func.inputs, interface_map, diff_map, funcDiffArgs, true); + der_func.locals := differentiateFunctionInterfaceNodes(der_func.locals, interface_map, diff_map, funcDiffArgs, true); + der_func.outputs := differentiateFunctionInterfaceNodes(der_func.outputs, interface_map, diff_map, funcDiffArgs, false); + + der_func.locals := listAppend(der_func.locals, local_outputs); + // create "fake" function with correct interface to have the interface // in the case of recursive differentiation (e.g. function calls itself) dummy_func := func; @@ -1328,22 +1338,149 @@ public input output list interface_nodes; input UnorderedMap interface_map; input UnorderedMap diff_map; + input output DifferentiationArguments diffArgs; input Boolean keepOld; protected list new_nodes; ComponentRef cref, diff_cref; + InstNode comp_node; + Component comp; + Binding binding; algorithm new_nodes := if keepOld then listReverse(interface_nodes) else {}; interface_nodes := list(node for node guard(not UnorderedMap.contains(InstNode.name(node), interface_map)) in interface_nodes); for node in interface_nodes loop cref := ComponentRef.fromNode(node, InstNode.getType(node)); - diff_cref := BVariable.makeFDerVar(cref); - UnorderedMap.add(cref, diff_cref, diff_map); + diff_cref := UnorderedMap.getSafe(cref, diff_map, sourceInfo()); + diff_cref := match diff_cref + case ComponentRef.CREF(node = comp_node as InstNode.COMPONENT_NODE()) algorithm + // differentiate bindings + comp := Pointer.access(comp_node.component); + comp := match comp + case comp as Component.COMPONENT() algorithm + (binding, diffArgs) := differentiateBinding(comp.binding, diffArgs); + comp.binding := binding; + then comp; + else comp; + end match; + comp_node.component := Pointer.create(comp); + diff_cref.node := comp_node; + then diff_cref; + else diff_cref; + end match; new_nodes := ComponentRef.node(diff_cref) :: new_nodes; end for; interface_nodes := listReverse(new_nodes); end differentiateFunctionInterfaceNodes; + function createInterfaceDerivatives + input list interface_nodes; + input UnorderedMap interface_map; + input UnorderedMap diff_map; + protected + ComponentRef cref, diff_cref; + list n; + algorithm + n := list(node for node guard(not UnorderedMap.contains(InstNode.name(node), interface_map)) in interface_nodes); + for node in n loop + cref := ComponentRef.fromNode(node, InstNode.getType(node)); + diff_cref := BVariable.makeFDerVar(cref); + UnorderedMap.add(cref, diff_cref, diff_map); + end for; + end createInterfaceDerivatives; + + function resolvePartialDerivatives + input output Function func; + input output FunctionTree funcTree; + protected + Function der_func; + InstNode node; + Pointer cls, tmp_cls; + Class new_cls, wrap_cls; + Sections sections; + UnorderedMap diff_map = UnorderedMap.new(ComponentRef.hash, ComponentRef.isEqual); + UnorderedMap interface_map; + DifferentiationArguments diffArgs = DifferentiationArguments.default(); + list algorithms; + CachedData cachedData; + InstNode diffVar; + ComponentRef diffCref; + list local_outputs; + Boolean changed = false; + algorithm + func := match func + case der_func as Function.FUNCTION(node = InstNode.CLASS_NODE(cls = cls)) algorithm + wrap_cls := Pointer.access(cls); + new_cls := match wrap_cls + case wrap_cls as Class.TYPED_DERIVED(baseClass = node as InstNode.CLASS_NODE(cls = tmp_cls)) algorithm + new_cls := match Pointer.access(tmp_cls) + case new_cls as Class.INSTANCED_CLASS(sections = sections as Sections.SECTIONS(algorithms = algorithms)) algorithm + // prepare differentiation arguments + diffArgs.diffType := DifferentiationType.FUNCTION; + diffArgs.funcTree := funcTree; + + interface_map := UnorderedMap.fromLists(list(InstNode.name(var) for var in der_func.inputs), List.fill(false, listLength(der_func.inputs)), stringHashDjb2, stringEqual); + + // add all differentiated inputs to the interface map + for var in List.getAtIndexLst(der_func.inputs, der_func.derivedInputs) loop + UnorderedMap.remove(InstNode.name(var), interface_map); + + // prepare outputs that become locals + local_outputs := list(InstNode.setComponentDirection(NFPrefixes.Direction.NONE, node) for node in der_func.outputs); + local_outputs := list(InstNode.protect(node) for node in local_outputs); + + // differentiate interface arguments + createInterfaceDerivatives({var}, interface_map, diff_map); + createInterfaceDerivatives(der_func.locals, interface_map, diff_map); + createInterfaceDerivatives(der_func.outputs, interface_map, diff_map); + diffArgs.jacobianHT := SOME(diff_map); + + der_func.locals := differentiateFunctionInterfaceNodes(der_func.locals, interface_map, diff_map, diffArgs, true); + der_func.outputs := differentiateFunctionInterfaceNodes(der_func.outputs, interface_map, diff_map, diffArgs, false); + + diffCref := UnorderedMap.getSafe(ComponentRef.fromNode(var, InstNode.getType(var)), diff_map, sourceInfo()); + der_func.locals := listAppend(der_func.locals, local_outputs); + + // differentiate function statements + (algorithms, diffArgs) := List.mapFold(algorithms, differentiateAlgorithm, diffArgs); + algorithms := Algorithm.mapExpList(algorithms, function Replacements.single(old = Expression.fromCref(diffCref), new = Expression.makeOne(ComponentRef.getSubscriptedType(diffCref)))); + + UnorderedMap.add(InstNode.name(var), false, interface_map); + end for; + + // add them to new node + sections.algorithms := algorithms; + new_cls.sections := sections; + new_cls.ty := wrap_cls.ty; + new_cls.restriction := wrap_cls.restriction; + node.cls := Pointer.create(new_cls); + cachedData := CachedData.FUNCTION({der_func}, true, false); + der_func.node := InstNode.setFuncCache(node, cachedData); + der_func.derivatives := {}; + der_func.derivedInputs := {}; + + changed := true; + then new_cls; + + else wrap_cls; + end match; + then new_cls; + else wrap_cls; + end match; + + if changed then + if Flags.isSet(Flags.DEBUG_DIFFERENTIATION) then + print("\n[BEFORE] " + Function.toFlatString(func) + "\n"); + print("\n[AFTER ] " + Function.toFlatString(der_func) + "\n\n"); + end if; + funcTree := FunctionTreeImpl.add(funcTree, der_func.path, der_func, FunctionTreeImpl.addConflictReplace); + end if; + then der_func; + + else func; + end match; + end resolvePartialDerivatives; + function differentiateAlgorithm input output Algorithm alg; input output DifferentiationArguments diffArguments; @@ -1743,6 +1880,20 @@ public end match; end differentiateEquationAttributes; + function differentiateBinding + input output Binding binding; + input output DifferentiationArguments diffArgs; + protected + Option opt_exp; + Expression exp; + algorithm + opt_exp := Binding.getExpOpt(binding); + if Util.isSome(opt_exp) then + (exp, diffArgs) := differentiateExpression(Util.getOption(opt_exp), diffArgs); + binding := Binding.setExp(exp, binding); + end if; + end differentiateBinding; + protected function minusOne input output Expression exp; diff --git a/OMCompiler/Compiler/NFFrontEnd/NFInstNode.mo b/OMCompiler/Compiler/NFFrontEnd/NFInstNode.mo index cddad9db460..7c990978b11 100644 --- a/OMCompiler/Compiler/NFFrontEnd/NFInstNode.mo +++ b/OMCompiler/Compiler/NFFrontEnd/NFInstNode.mo @@ -1717,6 +1717,26 @@ uniontype InstNode end match; end protectComponent; + function protect + input output InstNode node; + algorithm + () := match node + case COMPONENT_NODE(visibility = Visibility.PUBLIC) + algorithm + node.visibility := Visibility.PROTECTED; + then + (); + + case CLASS_NODE(visibility = Visibility.PUBLIC) + algorithm + node.visibility := Visibility.PROTECTED; + then + (); + + else (); + end match; + end protect; + function isEncapsulated input InstNode node; output Boolean enc; diff --git a/OMCompiler/Compiler/Util/List.mo b/OMCompiler/Compiler/Util/List.mo index 1702d38b615..7ab95625b61 100644 --- a/OMCompiler/Compiler/Util/List.mo +++ b/OMCompiler/Compiler/Util/List.mo @@ -733,14 +733,12 @@ public function getAtIndexLst input list lst; input list positions; input Boolean zeroBased = false; - output list olst = {}; + output list olst; protected array arr = listArray(lst); Integer shift = if zeroBased then 1 else 0; algorithm - for pos in listReverse(positions) loop - olst := arr[pos+shift] :: olst; - end for; + olst := list(arr[pos+shift] for pos in positions); end getAtIndexLst; public function firstN diff --git a/testsuite/simulation/modelica/NBackend/functions/Makefile b/testsuite/simulation/modelica/NBackend/functions/Makefile index 469085a15cc..a557152b4da 100644 --- a/testsuite/simulation/modelica/NBackend/functions/Makefile +++ b/testsuite/simulation/modelica/NBackend/functions/Makefile @@ -4,6 +4,7 @@ TESTFILES = \ builtin_functions.mos\ function_annotation_der.mos\ function_diff.mos\ +function_partial_der.mos\ # test that currently fail. Move up when fixed. # Run make failingtest diff --git a/testsuite/simulation/modelica/NBackend/functions/function_partial_der.mos b/testsuite/simulation/modelica/NBackend/functions/function_partial_der.mos new file mode 100644 index 00000000000..45624b6156d --- /dev/null +++ b/testsuite/simulation/modelica/NBackend/functions/function_partial_der.mos @@ -0,0 +1,53 @@ +// name: function_partial_der +// keywords: NewBackend +// status: correct + +loadString(" +model function_partial_der + function sinwave + input Real x; + output Real y; + protected + Real b = 1.0 \"tests the differentiation of locals\"; + algorithm + y := sin(x) * b; + end sinwave; + function coswave = der(sinwave,x); + Real x, cosVal, realcos; + equation + x = time; + cosVal = coswave(x); + realcos = cos(x); +end function_partial_der; +"); getErrorString(); + +setCommandLineOptions("--newBackend -d=debugDifferentiation"); getErrorString(); +simulate(function_partial_der); getErrorString(); +// Result: +// true +// "" +// true +// "" +// +// [BEFORE] function 'function_partial_der.coswave' = der('function_partial_der.sinwave', x) +// +// [AFTER ] function 'function_partial_der.coswave' +// input Real 'x'; +// output Real '$fDER_y'; +// Real 'b' = 1.0; +// Real '$fDER_b' = 0.0; +// Real 'y'; +// algorithm +// '$fDER_y' := sin('x') * '$fDER_b' + (cos('x') * 1.0) * 'b'; +// 'y' := sin('x') * 'b'; +// end 'function_partial_der.coswave' +// +// record SimulationResult +// resultFile = "function_partial_der_res.mat", +// simulationOptions = "startTime = 0.0, stopTime = 1.0, numberOfIntervals = 500, tolerance = 1e-6, method = 'dassl', fileNamePrefix = 'function_partial_der', options = '', outputFormat = 'mat', variableFilter = '.*', cflags = '', simflags = ''", +// messages = "LOG_SUCCESS | info | The initialization finished successfully without homotopy method. +// LOG_SUCCESS | info | The simulation finished successfully. +// " +// end SimulationResult; +// "" +// endResult