diff --git a/Compiler/NFFrontEnd/NFBuiltinCall.mo b/Compiler/NFFrontEnd/NFBuiltinCall.mo index a868169d2a9..9430f4dc577 100644 --- a/Compiler/NFFrontEnd/NFBuiltinCall.mo +++ b/Compiler/NFFrontEnd/NFBuiltinCall.mo @@ -59,6 +59,7 @@ protected import Typing = NFTyping; import Util; import ExpandExp = NFExpandExp; + import Operator = NFOperator; public function needSpecialHandling @@ -119,12 +120,12 @@ public case "ones" then typeZerosOnesCall("ones", call, origin, info); case "potentialRoot" then typePotentialRootCall(call, origin, info); case "pre" then typePreCall(call, origin, info); - case "product" then typeSumProductCall(call, origin, info); + case "product" then typeProductCall(call, origin, info); case "root" then typeRootCall(call, origin, info); case "rooted" then typeRootedCall(call, origin, info); case "scalar" then typeScalarCall(call, origin, info); case "smooth" then typeSmoothCall(call, origin, info); - case "sum" then typeSumProductCall(call, origin, info); + case "sum" then typeSumCall(call, origin, info); case "symmetric" then typeSymmetricCall(call, origin, info); case "terminal" then typeDiscreteCall(call, origin, info); case "transpose" then typeTransposeCall(call, origin, info); @@ -681,23 +682,109 @@ protected // fix return type. end typeMinMaxCall; - function typeSumProductCall + function typeSumCall input Call call; input ExpOrigin.Type origin; input SourceInfo info; output Expression callExp; output Type ty; - output Variability var; + output Variability variability; protected - Call argtycall; + ComponentRef fn_ref; + list args; + list named_args; + Expression arg; + Function fn; + Boolean expanded; + Operator op; algorithm - // TODO: Rewrite this whole thing. - argtycall := Call.typeMatchNormalCall(call, origin, info); - argtycall := Call.unboxArgs(argtycall); - ty := Call.typeOf(argtycall); - var := Call.variability(argtycall); - callExp := Expression.CALL(argtycall); - end typeSumProductCall; + Call.UNTYPED_CALL(ref = fn_ref, arguments = args, named_args = named_args) := call; + assertNoNamedParams("sum", named_args, info); + + if listLength(args) <> 1 then + Error.addSourceMessageAndFail(Error.NO_MATCHING_FUNCTION_FOUND_NFINST, + {Call.toString(call), "sum(Any[:, ...]) => Any"}, info); + end if; + + (arg, ty, variability) := Typing.typeExp(listHead(args), origin, info); + ty := Type.arrayElementType(ty); + + if intBitAnd(origin, intBitOr(ExpOrigin.FUNCTION, ExpOrigin.ALGORITHM)) == 0 then + (arg, expanded) := ExpandExp.expand(arg); + else + expanded := false; + end if; + + if expanded then + args := Expression.arrayScalarElements(arg); + + if listEmpty(args) then + callExp := Expression.makeZero(ty); + else + callExp :: args := args; + op := Operator.makeAdd(ty); + + for e in args loop + callExp := Expression.BINARY(callExp, op, e); + end for; + end if; + else + {fn} := Function.typeRefCache(fn_ref); + callExp := Expression.CALL(Call.makeTypedCall(fn, {arg}, variability, ty)); + end if; + end typeSumCall; + + function typeProductCall + input Call call; + input ExpOrigin.Type origin; + input SourceInfo info; + output Expression callExp; + output Type ty; + output Variability variability; + protected + ComponentRef fn_ref; + list args; + list named_args; + Expression arg; + Function fn; + Boolean expanded; + Operator op; + algorithm + Call.UNTYPED_CALL(ref = fn_ref, arguments = args, named_args = named_args) := call; + assertNoNamedParams("product", named_args, info); + + if listLength(args) <> 1 then + Error.addSourceMessageAndFail(Error.NO_MATCHING_FUNCTION_FOUND_NFINST, + {Call.toString(call), "product(Any[:, ...]) => Any"}, info); + end if; + + (arg, ty, variability) := Typing.typeExp(listHead(args), origin, info); + ty := Type.arrayElementType(ty); + + if intBitAnd(origin, intBitOr(ExpOrigin.FUNCTION, ExpOrigin.ALGORITHM)) == 0 then + (arg, expanded) := ExpandExp.expand(arg); + else + expanded := false; + end if; + + if expanded then + args := Expression.arrayScalarElements(arg); + + if listEmpty(args) then + callExp := Expression.makeOne(ty); + else + callExp :: args := args; + op := Operator.makeMul(ty); + + for e in args loop + callExp := Expression.BINARY(callExp, op, e); + end for; + end if; + else + {fn} := Function.typeRefCache(fn_ref); + callExp := Expression.CALL(Call.makeTypedCall(fn, {arg}, variability, ty)); + end if; + end typeProductCall; function typeSmoothCall input Call call; @@ -884,7 +971,6 @@ protected list args; list named_args; Expression arg; - Variability var; Function fn; Boolean expanded; algorithm