Skip to content

Commit 89d6008

Browse files
authored
[NB] make correct div operators when differentiating (#13626)
1 parent 18fb60d commit 89d6008

File tree

2 files changed

+17
-2
lines changed

2 files changed

+17
-2
lines changed

OMCompiler/Compiler/NBackEnd/Util/NBDifferentiate.mo

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1765,7 +1765,8 @@ public
17651765
(_, sizeClass) := Operator.classify(operator);
17661766
addOp := Operator.fromClassification((NFOperator.MathClassification.ADDITION, sizeClass), operator.ty);
17671767
mulOp := Operator.fromClassification((NFOperator.MathClassification.MULTIPLICATION, sizeClass), operator.ty);
1768-
powOp := Operator.fromClassification((NFOperator.MathClassification.POWER, sizeClass), operator.ty);
1768+
powOp := Operator.fromClassification((NFOperator.MathClassification.POWER,
1769+
Operator.combineSizeClassification(sizeClass, NFOperator.SizeClassification.SCALAR)), operator.ty);
17691770
then (Expression.MULTARY(
17701771
{Expression.MULTARY(
17711772
{Expression.BINARY(exp1, mulOp, diffExp2)}, // fg'
@@ -1894,7 +1895,8 @@ public
18941895
// create addition and power operator
18951896
(_, sizeClass) := Operator.classify(operator);
18961897
addOp := Operator.fromClassification((NFOperator.MathClassification.ADDITION, sizeClass), operator.ty);
1897-
powOp := Operator.fromClassification((NFOperator.MathClassification.POWER, sizeClass), operator.ty);
1898+
powOp := Operator.fromClassification((NFOperator.MathClassification.POWER,
1899+
Operator.combineSizeClassification(sizeClass, NFOperator.SizeClassification.SCALAR)), operator.ty);
18981900
// f'
18991901
(diff_arguments, diffArguments) := differentiateMultaryMultiplicationArgs(arguments, diffArguments, operator);
19001902
diff_enumerator := Expression.MULTARY(diff_arguments, {}, addOp);

OMCompiler/Compiler/NFFrontEnd/NFOperator.mo

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -853,6 +853,19 @@ public
853853
(_, scl) := classify(op);
854854
end getSizeClassification;
855855

856+
function combineSizeClassification
857+
input SizeClassification scl1;
858+
input SizeClassification scl2;
859+
output SizeClassification scl;
860+
algorithm
861+
scl := match (scl1, scl2)
862+
// Todo: more cases?
863+
case (SizeClassification.ELEMENT_WISE, SizeClassification.SCALAR) then SizeClassification.ARRAY_SCALAR;
864+
case (SizeClassification.SCALAR, SizeClassification.ELEMENT_WISE) then SizeClassification.SCALAR_ARRAY;
865+
else scl1;
866+
end match;
867+
end combineSizeClassification;
868+
856869
function isDashClassification
857870
input MathClassification mcl;
858871
output Boolean b;

0 commit comments

Comments
 (0)