Skip to content

Commit

Permalink
[NB] Fix derivatives of some builtin functions (#8643)
Browse files Browse the repository at this point in the history
The function values of abs, mod and rem are indeed not discrete if
their arguments are not, so their derivatives are not zero.
  • Loading branch information
phannebohm committed Mar 6, 2022
1 parent f2b1849 commit 39088fb
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 37 deletions.
72 changes: 51 additions & 21 deletions OMCompiler/Compiler/NBackEnd/Util/NBDifferentiate.mo
Expand Up @@ -49,6 +49,7 @@ public
import NFFunction.Function;
import NFFlatten.{FunctionTree, FunctionTreeImpl};
import Operator = NFOperator;
import Prefixes = NFPrefixes;
import SimplifyExp = NFSimplifyExp;
import Type = NFType;
import NFPrefixes.Variability;
Expand Down Expand Up @@ -631,7 +632,6 @@ public
Expression ret;
Boolean has_derviative_annotation = false;
Call call, der_call;
String name;
Option<Function> func_opt;
list<Function> derivatives;
Function func, der_func;
Expand All @@ -640,8 +640,7 @@ public

// builtin functions
case Expression.CALL(call = call as Call.TYPED_CALL()) guard(Function.isBuiltin(call.fn)) algorithm
name := AbsynUtil.pathString(Function.nameConsiderBuiltin(call.fn));
ret := differentiateBuiltinCall(name, exp, diffArguments);
ret := differentiateBuiltinCall(AbsynUtil.pathString(Function.nameConsiderBuiltin(call.fn)), exp, diffArguments);
then (ret, diffArguments);

// user defined functions
Expand All @@ -662,6 +661,7 @@ public
then Expression.CALL(Call.makeTypedCall(der_func, listAppend(call.arguments, arguments), call.var, call.purity));

// ERROR - more than one derivative of order 1 defined
// TODO pick first one according to MLS 3.5 section 12.7.1
else algorithm
Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed because there were " + intString(listLength(derivatives))
+ " derivatives of order 1 (expected is exactly one).\n Derivatives:" + List.toString(derivatives, function Function.signatureString(printTypes = true), "", "", "\n", "")});
Expand Down Expand Up @@ -705,7 +705,7 @@ public
Expression ret, ret1, ret2, arg1, arg2, diffArg1, diffArg2;

// Builtin function call with one argument
// df/dx = df/dy * dy/dx
// df(y)/dx = df/dy * dy/dx
case (Expression.CALL()) guard(listLength(Call.arguments(exp.call)) == 1)
algorithm
// differentiate the call
Expand Down Expand Up @@ -762,13 +762,20 @@ public
local
Expression ret;

// all these are integer based and therefore zero
case ("abs") then Expression.makeZero(Type.INTEGER());
case ("sign") then Expression.makeZero(Type.INTEGER());
case ("ceil") then Expression.makeZero(Type.INTEGER());
case ("floor") then Expression.makeZero(Type.INTEGER());
// all these have integer values and therefore zero derivative
case ("sign") then Expression.makeZero(Expression.typeOf(arg));
case ("ceil") then Expression.makeZero(Type.REAL());
case ("floor") then Expression.makeZero(Type.REAL());
case ("integer") then Expression.makeZero(Type.INTEGER());

// abs(arg) -> sign(arg)
case ("abs") then Expression.CALL(Call.makeTypedCall(
fn = NFBuiltinFuncs.SIGN_REAL,
args = {arg},
variability = Expression.variability(arg),
purity = NFPrefixes.Purity.PURE
));

// sqrt(arg) -> 0.5/arg^(0.5)
case ("sqrt") algorithm
ret := Expression.BINARY(arg, powOp, Expression.REAL(0.5)); // arg^0.5
Expand All @@ -784,7 +791,7 @@ public
));

// cos(arg) -> -sin(arg)
case("cos") then Expression.negate(Expression.CALL(Call.makeTypedCall(
case ("cos") then Expression.negate(Expression.CALL(Call.makeTypedCall(
fn = NFBuiltinFuncs.SIN_REAL,
args = {arg},
variability = Expression.variability(arg),
Expand All @@ -793,7 +800,7 @@ public

// tan(arg) -> 1/cos(arg)^2
// kabdelhak: ToDo - investigate numerical properties: 1+tan(arg)^2 maybe better?
case("tan") algorithm
case ("tan") algorithm
ret := Expression.CALL(Call.makeTypedCall(
fn = NFBuiltinFuncs.COS_REAL,
args = {arg},
Expand All @@ -804,23 +811,23 @@ public
then ret;

// asin(arg) -> 1/sqrt(1-arg^2)
case("asin") algorithm
case ("asin") algorithm
ret := Expression.BINARY(arg, powOp, Expression.REAL(2.0)); // arg^2
ret := Expression.MULTARY({Expression.REAL(1.0)}, {ret}, addOp); // 1-arg^2
ret := Expression.BINARY(ret, powOp, Expression.REAL(0.5)); // sqrt(1-arg^2)
ret := Expression.MULTARY({Expression.REAL(1.0)}, {ret}, mulOp); // 1/sqrt(1-arg^2)
then ret;

// acos(arg) -> -1/sqrt(1-arg^2)
case("acos") algorithm
case ("acos") algorithm
ret := Expression.BINARY(arg, powOp, Expression.REAL(2.0)); // arg^2
ret := Expression.MULTARY({Expression.REAL(1.0)}, {ret}, addOp); // 1-arg^2
ret := Expression.BINARY(ret, powOp, Expression.REAL(0.5)); // sqrt(1-arg^2)
ret := Expression.MULTARY({Expression.REAL(-1.0)}, {ret}, mulOp); // -1/sqrt(1-arg^2)
then ret;

// atan(arg) -> 1/(1+arg^2)
case("atan") algorithm
case ("atan") algorithm
ret := Expression.BINARY(arg, powOp, Expression.REAL(2.0)); // arg^2
ret := Expression.MULTARY({Expression.REAL(1.0), ret}, {}, addOp);// 1+arg^2
ret := Expression.MULTARY({Expression.REAL(1.0)}, {ret}, mulOp); // 1/(1+arg^2)
Expand All @@ -835,15 +842,15 @@ public
));

// cosh(arg) -> sinh(arg)
case("cosh") then Expression.CALL(Call.makeTypedCall(
case ("cosh") then Expression.CALL(Call.makeTypedCall(
fn = NFBuiltinFuncs.SINH_REAL,
args = {arg},
variability = Expression.variability(arg),
purity = NFPrefixes.Purity.PURE
));

// tanh(arg) -> 1-tanh(arg)^2
case("tanh") algorithm
case ("tanh") algorithm
ret := Expression.CALL(Call.makeTypedCall(
fn = NFBuiltinFuncs.TANH_REAL,
args = {arg},
Expand All @@ -869,9 +876,9 @@ public
ret := Expression.CALL(Call.makeTypedCall(
fn = NFBuiltinFuncs.LOG_REAL,
args = {Expression.REAL(10.0)},
variability = Expression.variability(arg),
variability = Variability.CONSTANT,
purity = NFPrefixes.Purity.PURE)); // log(10)
ret := Expression.MULTARY({Expression.REAL(1.0)}, {arg, ret}, mulOp); // 1/arg*log(10)
ret := Expression.MULTARY({Expression.REAL(1.0)}, {arg, ret}, mulOp); // 1/(arg*log(10))
then ret;

else algorithm
Expand All @@ -898,10 +905,33 @@ public
local
Expression exp1, exp2, ret1, ret2;

// all these are integer based and therefore zero
// div(arg1, arg2) truncates the fractional part of arg1/arg2 so it has discrete values
// therefore it has zero derivative where it's defined
case ("div") then (Expression.makeZero(Type.INTEGER()), Expression.makeZero(Type.INTEGER()));
case ("mod") then (Expression.makeZero(Type.INTEGER()), Expression.makeZero(Type.INTEGER()));
case ("rem") then (Expression.makeZero(Type.INTEGER()), Expression.makeZero(Type.INTEGER()));

// d/darg1 mod(arg1, arg2) -> 1
// d/darg2 mod(arg1, arg2) -> -floor(arg1/arg2)
case ("mod") algorithm
exp2 := Expression.CALL(Call.makeTypedCall(
fn = NFBuiltinFuncs.FLOOR,
args = {Expression.MULTARY({arg1}, {arg2}, mulOp)}, // arg1/arg2
variability = Prefixes.variabilityMax(Expression.variability(arg1), Expression.variability(arg2)),
purity = NFPrefixes.Purity.PURE
)); // floor(arg1/arg2)
ret2 := Expression.negate(exp2); // -floor(arg1/arg2)
then (Expression.makeOne(Type.REAL()), ret2);

// d/darg1 rem(arg1, arg2) -> 1
// d/darg2 rem(arg1, arg2) -> -div(arg1, arg2)
case ("rem") algorithm
exp2 := Expression.CALL(Call.makeTypedCall(
fn = NFBuiltinFuncs.DIV_REAL,
args = {arg1, arg2},
variability = Prefixes.variabilityMax(Expression.variability(arg1), Expression.variability(arg2)),
purity = NFPrefixes.Purity.PURE
)); // div(arg1, arg2)
ret2 := Expression.negate(exp2); // -div(arg1, arg2)
then (Expression.makeOne(Type.REAL()), ret2);

// d/darg1 atan2(arg1, arg2) -> -arg2/(arg1^2+arg2^2)
// d/darg2 atan2(arg1, arg2) -> arg1/(arg1^2+arg2^2)
Expand Down
30 changes: 20 additions & 10 deletions OMCompiler/Compiler/NFFrontEnd/NFBuiltinFuncs.mo
Expand Up @@ -238,52 +238,57 @@ constant ComponentRef STRING_CREF =

// TODO: Sort these functions ...
constant Function COS_REAL = Function.FUNCTION(Path.IDENT("cos"),
InstNode.EMPTY_NODE(), {REAL_PARAM, REAL_PARAM}, {REAL_PARAM}, {}, {},
InstNode.EMPTY_NODE(), {REAL_PARAM}, {REAL_PARAM}, {}, {},
Type.REAL(), DAE.FUNCTION_ATTRIBUTES_BUILTIN, {}, listArray({}),
Pointer.createImmutable(FunctionStatus.BUILTIN), Pointer.createImmutable(0));

constant Function SIN_REAL = Function.FUNCTION(Path.IDENT("sin"),
InstNode.EMPTY_NODE(), {REAL_PARAM, REAL_PARAM}, {REAL_PARAM}, {}, {},
InstNode.EMPTY_NODE(), {REAL_PARAM}, {REAL_PARAM}, {}, {},
Type.REAL(), DAE.FUNCTION_ATTRIBUTES_BUILTIN, {}, listArray({}),
Pointer.createImmutable(FunctionStatus.BUILTIN), Pointer.createImmutable(0));

constant Function TAN_REAL = Function.FUNCTION(Path.IDENT("tan"),
InstNode.EMPTY_NODE(), {REAL_PARAM, REAL_PARAM}, {REAL_PARAM}, {}, {},
InstNode.EMPTY_NODE(), {REAL_PARAM}, {REAL_PARAM}, {}, {},
Type.REAL(), DAE.FUNCTION_ATTRIBUTES_BUILTIN, {}, listArray({}),
Pointer.createImmutable(FunctionStatus.BUILTIN), Pointer.createImmutable(0));

constant Function COSH_REAL = Function.FUNCTION(Path.IDENT("cosh"),
InstNode.EMPTY_NODE(), {REAL_PARAM, REAL_PARAM}, {REAL_PARAM}, {}, {},
InstNode.EMPTY_NODE(), {REAL_PARAM}, {REAL_PARAM}, {}, {},
Type.REAL(), DAE.FUNCTION_ATTRIBUTES_BUILTIN, {}, listArray({}),
Pointer.createImmutable(FunctionStatus.BUILTIN), Pointer.createImmutable(0));

constant Function SINH_REAL = Function.FUNCTION(Path.IDENT("sinh"),
InstNode.EMPTY_NODE(), {REAL_PARAM, REAL_PARAM}, {REAL_PARAM}, {}, {},
InstNode.EMPTY_NODE(), {REAL_PARAM}, {REAL_PARAM}, {}, {},
Type.REAL(), DAE.FUNCTION_ATTRIBUTES_BUILTIN, {}, listArray({}),
Pointer.createImmutable(FunctionStatus.BUILTIN), Pointer.createImmutable(0));

constant Function TANH_REAL = Function.FUNCTION(Path.IDENT("tanh"),
InstNode.EMPTY_NODE(), {REAL_PARAM, REAL_PARAM}, {REAL_PARAM}, {}, {},
InstNode.EMPTY_NODE(), {REAL_PARAM}, {REAL_PARAM}, {}, {},
Type.REAL(), DAE.FUNCTION_ATTRIBUTES_BUILTIN, {}, listArray({}),
Pointer.createImmutable(FunctionStatus.BUILTIN), Pointer.createImmutable(0));

constant Function EXP_REAL = Function.FUNCTION(Path.IDENT("exp"),
InstNode.EMPTY_NODE(), {REAL_PARAM, REAL_PARAM}, {REAL_PARAM}, {}, {},
InstNode.EMPTY_NODE(), {REAL_PARAM}, {REAL_PARAM}, {}, {},
Type.REAL(), DAE.FUNCTION_ATTRIBUTES_BUILTIN, {}, listArray({}),
Pointer.createImmutable(FunctionStatus.BUILTIN), Pointer.createImmutable(0));

constant Function LOG_REAL = Function.FUNCTION(Path.IDENT("log"),
InstNode.EMPTY_NODE(), {REAL_PARAM, REAL_PARAM}, {REAL_PARAM}, {}, {},
InstNode.EMPTY_NODE(), {REAL_PARAM}, {REAL_PARAM}, {}, {},
Type.REAL(), DAE.FUNCTION_ATTRIBUTES_BUILTIN, {}, listArray({}),
Pointer.createImmutable(FunctionStatus.BUILTIN), Pointer.createImmutable(0));

constant Function LOG10_REAL = Function.FUNCTION(Path.IDENT("log10"),
InstNode.EMPTY_NODE(), {REAL_PARAM, REAL_PARAM}, {REAL_PARAM}, {}, {},
InstNode.EMPTY_NODE(), {REAL_PARAM}, {REAL_PARAM}, {}, {},
Type.REAL(), DAE.FUNCTION_ATTRIBUTES_BUILTIN, {}, listArray({}),
Pointer.createImmutable(FunctionStatus.BUILTIN), Pointer.createImmutable(0));

constant Function ABS_REAL = Function.FUNCTION(Path.IDENT("abs"),
InstNode.EMPTY_NODE(), {REAL_PARAM, REAL_PARAM}, {REAL_PARAM}, {}, {},
InstNode.EMPTY_NODE(), {REAL_PARAM}, {REAL_PARAM}, {}, {},
Type.REAL(), DAE.FUNCTION_ATTRIBUTES_BUILTIN, {}, listArray({}),
Pointer.createImmutable(FunctionStatus.BUILTIN), Pointer.createImmutable(0));

constant Function SIGN_REAL = Function.FUNCTION(Path.IDENT("sign"),
InstNode.EMPTY_NODE(), {REAL_PARAM}, {REAL_PARAM}, {}, {},
Type.REAL(), DAE.FUNCTION_ATTRIBUTES_BUILTIN, {}, listArray({}),
Pointer.createImmutable(FunctionStatus.BUILTIN), Pointer.createImmutable(0));

Expand All @@ -302,6 +307,11 @@ constant Function DIV_INT = Function.FUNCTION(Path.IDENT("div"),
Type.INTEGER(), DAE.FUNCTION_ATTRIBUTES_BUILTIN, {}, listArray({}),
Pointer.createImmutable(FunctionStatus.BUILTIN), Pointer.createImmutable(0));

constant Function DIV_REAL = Function.FUNCTION(Path.IDENT("div"),
InstNode.EMPTY_NODE(), {REAL_PARAM, REAL_PARAM}, {REAL_PARAM}, {}, {},
Type.REAL(), DAE.FUNCTION_ATTRIBUTES_BUILTIN, {}, listArray({}),
Pointer.createImmutable(FunctionStatus.BUILTIN), Pointer.createImmutable(0));

constant Function FLOOR = Function.FUNCTION(Path.IDENT("floor"),
InstNode.EMPTY_NODE(), {REAL_PARAM}, {REAL_PARAM}, {}, {},
Type.REAL(), DAE.FUNCTION_ATTRIBUTES_BUILTIN, {}, listArray({}),
Expand Down
4 changes: 1 addition & 3 deletions OMCompiler/Compiler/NFFrontEnd/NFFunctionDerivative.mo
Expand Up @@ -180,13 +180,11 @@ public
end toSubMod;

function getOrder
"returns true if the function derivative is of given order"
"returns the order of the given function derivative"
input FunctionDerivative funcDer;
output Integer order;
algorithm
order := match funcDer.order
local
Integer value;
case Expression.INTEGER(value = order) then order;
else algorithm
Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed because the order was not evaluated to be a constant: " + Expression.toString(funcDer.order)});
Expand Down
Expand Up @@ -10,7 +10,7 @@ model builtin_functions
Real y;
Real[14] x(each start=0.0, each fixed=true);
equation
y = sin(x[2]);
y = abs(x[2]);
der(x[1]) = sqrt(y);
der(x[2]) = sin(y);
der(x[3]) = cos(y);
Expand Down Expand Up @@ -52,7 +52,7 @@ simulate(builtin_functions); getErrorString();
// ### Variable:
// Real y
// ### Equation:
// [SCAL] (1) y = sin(x[2]) ($RES_SIM_14)
// [SCAL] (1) y = abs(x[2]) ($RES_SIM_14)
//
// --- Alias of INI[1 | 3] ---
// BLOCK 2: Sliced Equation (status = Solve.EXPLICIT)
Expand Down Expand Up @@ -172,7 +172,7 @@ simulate(builtin_functions); getErrorString();
//
// Result Equations (15/15)
// ****************************************
// (1) [SCAL] (1) $pDER_ODE_JAC.y = cos(x[2]) * $SEED_ODE_JAC.x[2] ($RES_ODE_JAC_0)
// (1) [SCAL] (1) $pDER_ODE_JAC.y = sign(x[2]) * $SEED_ODE_JAC.x[2] ($RES_ODE_JAC_0)
// (2) [SCAL] (1) $pDER_ODE_JAC.$DER.x[1] = 0.5 / y ^ 0.5 * $pDER_ODE_JAC.y ($RES_ODE_JAC_1)
// (3) [SCAL] (1) $pDER_ODE_JAC.$DER.x[2] = cos(y) * $pDER_ODE_JAC.y ($RES_ODE_JAC_2)
// (4) [SCAL] (1) $pDER_ODE_JAC.$DER.x[3] = -sin(y) * $pDER_ODE_JAC.y ($RES_ODE_JAC_3)
Expand Down

0 comments on commit 39088fb

Please sign in to comment.