From d090ace87f34a08fa6aa16efb930b9cf25c6d192 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Per=20=C3=96stlund?= Date: Fri, 14 Dec 2018 18:45:35 +0100 Subject: [PATCH] [NF] Improve operator overloading. - Implement scalar*array, array*scalar and array/scalar for overloaded operators. - Improve TypeCheck.implicitConstructAndMatch so that it checks that the constructed argument actually matches the expected type for the operator, to avoid it matching e.g. scalars with operators that only take arrays. Belonging to [master]: - OpenModelica/OMCompiler#2837 - OpenModelica/OpenModelica-testsuite#1095 --- Compiler/NFFrontEnd/NFTypeCheck.mo | 291 +++++++++++++++++++++++++++-- 1 file changed, 271 insertions(+), 20 deletions(-) diff --git a/Compiler/NFFrontEnd/NFTypeCheck.mo b/Compiler/NFFrontEnd/NFTypeCheck.mo index cf4220ae6d..5c65a947e2 100644 --- a/Compiler/NFFrontEnd/NFTypeCheck.mo +++ b/Compiler/NFFrontEnd/NFTypeCheck.mo @@ -250,6 +250,12 @@ algorithm if oop == Op.ADD or oop == Op.SUB then (outExp, outType) := checkOverloadedBinaryArrayAddSub(exp1, type1, var1, op, exp2, type2, var2, candidates, info); + elseif oop == Op.MUL then + (outExp, outType) := + checkOverloadedBinaryArrayMul(exp1, type1, var1, op, exp2, type2, var2, candidates, info); + elseif oop == Op.DIV then + (outExp, outType) := + checkOverloadedBinaryArrayDiv(exp1, type1, var1, op, exp2, type2, var2, candidates, info); else printUnresolvableTypeError(Expression.BINARY(exp1, op, exp2), {type1, type2}, info, showErrors); end if; @@ -363,6 +369,210 @@ algorithm end match; end checkOverloadedBinaryArrayAddSub2; +function checkOverloadedBinaryArrayMul + input Expression exp1; + input Type type1; + input Variability var1; + input Operator op; + input Expression exp2; + input Type type2; + input Variability var2; + input list candidates; + input SourceInfo info; + output Expression outExp; + output Type outType; +protected + Boolean valid; + list dims1, dims2; + Dimension dim11, dim12, dim21, dim22; +algorithm + dims1 := Type.arrayDims(type1); + dims2 := Type.arrayDims(type2); + + (valid, outExp) := match (dims1, dims2) + // scalar * array = array + case ({}, {_}) + algorithm + outExp := checkOverloadedBinaryScalarArray(exp1, type1, var1, op, exp2, type2, var2, candidates, info); + then + (true, outExp); + // array * scalar = array + case ({_}, {}) + algorithm + outExp := checkOverloadedBinaryArrayScalar(exp1, type1, var1, op, exp2, type2, var2, candidates, info); + then + (true, outExp); + // matrix[n, m] * vector[m] = vector[n] + case ({dim11, dim12}, {dim21}) + algorithm + valid := Dimension.isEqual(dim12, dim21); + // TODO: Implement me! + outExp := Expression.BINARY(exp1, op, exp2); + valid := false; + then + (valid, outExp); + // matrix[n, m] * matrix[m, p] = vector[n, p] + case ({dim11, dim12}, {dim21, dim22}) + algorithm + valid := Dimension.isEqual(dim12, dim21); + // TODO: Implement me! + outExp := Expression.BINARY(exp1, op, exp2); + valid := false; + then + (valid, outExp); + // scalar * scalar should never get here. + // vector * vector and vector * matrix are undefined for overloaded operators. + else (false, Expression.BINARY(exp1, op, exp2)); + end match; + + if not valid then + printUnresolvableTypeError(outExp, {type1, type2}, info); + end if; + + outType := Expression.typeOf(outExp); +end checkOverloadedBinaryArrayMul; + +function checkOverloadedBinaryScalarArray + input Expression exp1; + input Type type1; + input Variability var1; + input Operator op; + input Expression exp2; + input Type type2; + input Variability var2; + input list candidates; + input SourceInfo info; + output Expression outExp; + output Type outType; +algorithm + (outExp, outType) := checkOverloadedBinaryScalarArray2( + exp1, type1, var1, op, ExpandExp.expand(exp2), type2, var2, candidates, info); +end checkOverloadedBinaryScalarArray; + +function checkOverloadedBinaryScalarArray2 + input Expression exp1; + input Type type1; + input Variability var1; + input Operator op; + input Expression exp2; + input Type type2; + input Variability var2; + input list candidates; + input SourceInfo info; + output Expression outExp; + output Type outType; +protected + list expl; + Type ty; +algorithm + (outExp, outType) := match exp2 + case Expression.ARRAY(elements = {}) + algorithm + try + ty := Type.unliftArray(type2); + (_, outType) := matchOverloadedBinaryOperator( + exp1, type1, var1, op, Expression.EMPTY(type2), ty, var2, candidates, info, showErrors = false); + else + printUnresolvableTypeError(Expression.BINARY(exp1, op, exp2), {type1, exp2.ty}, info); + end try; + + outType := Type.setArrayElementType(exp2.ty, outType); + then + (Expression.makeArray(outType, {}), outType); + + case Expression.ARRAY(elements = expl) + algorithm + ty := Type.unliftArray(type2); + expl := list(checkOverloadedBinaryScalarArray2(exp1, type1, var1, op, e, ty, var2, candidates, info) for e in expl); + outType := Type.setArrayElementType(exp2.ty, Expression.typeOf(listHead(expl))); + then + (Expression.makeArray(outType, expl), outType); + + else matchOverloadedBinaryOperator(exp1, type1, var1, op, exp2, type2, var2, candidates, info); + end match; +end checkOverloadedBinaryScalarArray2; + +function checkOverloadedBinaryArrayScalar + input Expression exp1; + input Type type1; + input Variability var1; + input Operator op; + input Expression exp2; + input Type type2; + input Variability var2; + input list candidates; + input SourceInfo info; + output Expression outExp; + output Type outType; +algorithm + (outExp, outType) := checkOverloadedBinaryArrayScalar2( + ExpandExp.expand(exp1), type1, var1, op, exp2, type2, var2, candidates, info); +end checkOverloadedBinaryArrayScalar; + +function checkOverloadedBinaryArrayScalar2 + input Expression exp1; + input Type type1; + input Variability var1; + input Operator op; + input Expression exp2; + input Type type2; + input Variability var2; + input list candidates; + input SourceInfo info; + output Expression outExp; + output Type outType; +protected + Expression e1; + list expl; + Type ty; +algorithm + (outExp, outType) := match exp1 + case Expression.ARRAY(elements = {}) + algorithm + try + ty := Type.unliftArray(type1); + (_, outType) := matchOverloadedBinaryOperator( + Expression.EMPTY(type1), ty, var1, op, exp2, type2, var2, candidates, info, showErrors = false); + else + printUnresolvableTypeError(Expression.BINARY(exp1, op, exp2), {type1, exp1.ty}, info); + end try; + + outType := Type.setArrayElementType(exp1.ty, outType); + then + (Expression.makeArray(outType, {}), outType); + + case Expression.ARRAY(elements = expl) + algorithm + ty := Type.unliftArray(type1); + expl := list(checkOverloadedBinaryArrayScalar2(e, ty, var1, op, exp2, type2, var2, candidates, info) for e in expl); + outType := Type.setArrayElementType(exp1.ty, Expression.typeOf(listHead(expl))); + then + (Expression.makeArray(outType, expl), outType); + + else matchOverloadedBinaryOperator(exp1, type1, var1, op, exp2, type2, var2, candidates, info); + end match; +end checkOverloadedBinaryArrayScalar2; + +function checkOverloadedBinaryArrayDiv + input Expression exp1; + input Type type1; + input Variability var1; + input Operator op; + input Expression exp2; + input Type type2; + input Variability var2; + input list candidates; + input SourceInfo info; + output Expression outExp; + output Type outType; +algorithm + if Type.isArray(type1) and Type.isScalar(type2) then + (outExp, outType) := checkOverloadedBinaryArrayScalar(exp1, type1, var1, op, exp2, type2, var2, candidates, info); + else + printUnresolvableTypeError(Expression.BINARY(exp1, op, exp2), {type1, type2}, info); + end if; +end checkOverloadedBinaryArrayDiv; + function implicitConstructAndMatch input list candidates; input Expression inExp1; @@ -381,32 +591,32 @@ protected Function operfn; list, Variability>> matchedfuncs = {}; Expression exp1,exp2; - Type ty; + Type ty, arg1_ty, arg2_ty; Variability var; + Boolean matched; + SourceInfo arg1_info, arg2_info; algorithm exp1 := inExp1; exp2 := inExp2; for fn in candidates loop in1 :: in2 :: _ := fn.inputs; - (_, _, mk1) := matchTypes(InstNode.getType(in1),inType1,inExp1,false); - (_, _, mk2) := matchTypes(InstNode.getType(in2),inType2,inExp2,false); - - // If the first argument matched the expected one, then we try - // to construct the second argument to the class of the second input. - if mk1 == MatchKind.EXACT then - // We only want overloaded constructors when trying to implicitly construct. Default constructors are not considered. - scope := InstNode.classScope(in2); - fn_ref := Function.instFunction(Absyn.CREF_IDENT("'constructor'",{}),scope,InstNode.info(in2)); - exp2 := Expression.CALL(NFCall.UNTYPED_CALL(fn_ref, {inExp2}, {}, scope)); - (exp2, ty, var) := Call.typeCall(exp2, 0, InstNode.info(in1)); - matchedfuncs := (fn,{inExp1,exp2}, var)::matchedfuncs; - elseif mk2 == MatchKind.EXACT then - // We only want overloaded constructors when trying to implicitly construct. Default constructors are not considered. - scope := InstNode.classScope(in1); - fn_ref := Function.instFunction(Absyn.CREF_IDENT("'constructor'",{}),scope,InstNode.info(in1)); - exp1 := Expression.CALL(NFCall.UNTYPED_CALL(fn_ref, {inExp1}, {}, scope)); - (exp1, ty, var) := Call.typeCall(exp1, 0, InstNode.info(in2)); - matchedfuncs := (fn,{exp1,inExp2},var)::matchedfuncs; + arg1_ty := InstNode.getType(in1); + arg2_ty := InstNode.getType(in2); + arg1_info := InstNode.info(in1); + arg2_info := InstNode.info(in2); + + // Try to implicitly construct a matching record from the first argument. + (matchedfuncs, matched) := + implicitConstructAndMatch2(inExp1, inType1, inExp2, arg1_ty, + arg1_info, arg2_ty, arg2_info, InstNode.classScope(in2), fn, false, matchedfuncs); + + if matched then + continue; end if; + + // Try to implicitly construct a matching record from the second argument. + (matchedfuncs, matched) := + implicitConstructAndMatch2(inExp2, inType2, inExp1, arg2_ty, + arg2_info, arg1_ty, arg1_info, InstNode.classScope(in1), fn, true, matchedfuncs); end for; if listLength(matchedfuncs) == 1 then @@ -421,6 +631,47 @@ algorithm end if; end implicitConstructAndMatch; +function implicitConstructAndMatch2 + input Expression exp1; + input Type type1; + input Expression exp2; + input Type paramType1; + input SourceInfo paramInfo1; + input Type paramType2; + input SourceInfo paramInfo2; + input InstNode scope; + input Function fn; + input Boolean reverseArgs; + input output list, Variability>> matchedFns; + output Boolean matched; +protected + ComponentRef fn_ref; + Expression e1, e2; + MatchKind mk; + Variability var; + Type ty; +algorithm + (e1, _, mk) := matchTypes(paramType1, type1, exp1, false); + + // We only want overloaded constructors when trying to implicitly construct. + // Default constructors are not considered. + if mk == MatchKind.EXACT then + fn_ref := Function.instFunction(Absyn.CREF_IDENT("'constructor'", {}), scope, paramInfo2); + e2 := Expression.CALL(NFCall.UNTYPED_CALL(fn_ref, {exp2}, {}, scope)); + (e2, ty, var) := Call.typeCall(e2, 0, paramInfo1); + (_, _, mk) := matchTypes(paramType2, ty, e2, false); + + if mk == MatchKind.EXACT then + matchedFns := (fn, if reverseArgs then {e2, e1} else {e1, e2}, var) :: matchedFns; + matched := true; + else + matched := false; + end if; + else + matched := false; + end if; +end implicitConstructAndMatch2; + //function checkValidBinaryOperatorOverload // input String oper_name; // input Function oper_func;