From 16f6be8639c15e227efc0154604b6bd8f8ffe1ae Mon Sep 17 00:00:00 2001 From: Eric Milles Date: Thu, 19 May 2022 12:28:46 -0500 Subject: [PATCH] GROOVY-10271, GROOVY-10272: STC: process closure in ternary expression --- .../stc/StaticTypeCheckingVisitor.java | 28 ++++++++++++---- .../stc/TernaryOperatorSTCTest.groovy | 33 ++++++++++++++++++- 2 files changed, 54 insertions(+), 7 deletions(-) diff --git a/src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java b/src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java index 88df814abbf..91834456a7c 100644 --- a/src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java +++ b/src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java @@ -4176,12 +4176,10 @@ public void visitTernaryExpression(final TernaryExpression expression) { } Expression trueExpression = expression.getTrueExpression(); ClassNode typeOfTrue = findCurrentInstanceOfClass(trueExpression, null); - trueExpression.visit(this); - if (typeOfTrue == null) typeOfTrue = getType(trueExpression); + typeOfTrue = Optional.ofNullable(typeOfTrue).orElse(visitValueExpression(trueExpression)); typeCheckingContext.popTemporaryTypeInfo(); // instanceof doesn't apply to false branch Expression falseExpression = expression.getFalseExpression(); - falseExpression.visit(this); - ClassNode typeOfFalse = getType(falseExpression); + ClassNode typeOfFalse = visitValueExpression(falseExpression); ClassNode resultType; if (isNullConstant(trueExpression) && isNullConstant(falseExpression)) { // GROOVY-5523 @@ -4201,13 +4199,25 @@ && isOrImplements(typeOfFalse, typeOfTrue))) { // List/Collection/Iterable : [] popAssignmentTracking(oldTracker); } + /** + * @param expr true or false branch of ternary expression + * @return the inferred type of {@code expr} + */ + private ClassNode visitValueExpression(final Expression expr) { + if (expr instanceof ClosureExpression) { + ClassNode targetType = checkForTargetType(expr, null); + if (isFunctionalInterface(targetType)) + processFunctionalInterfaceAssignment(targetType, expr); + } + expr.visit(this); + return getType(expr); + } + /** * @param expr true or false branch of ternary expression * @param type the inferred type of {@code expr} */ private ClassNode checkForTargetType(final Expression expr, final ClassNode type) { - ClassNode sourceType = Optional.ofNullable(getInferredReturnType(expr)).orElse(type); - ClassNode targetType = null; MethodNode enclosingMethod = typeCheckingContext.getEnclosingMethod(); BinaryExpression enclosingExpression = typeCheckingContext.getEnclosingBinaryExpression(); @@ -4222,6 +4232,12 @@ && isTypeSource(expr, enclosingMethod)) { targetType = enclosingMethod.getReturnType(); } + if (expr instanceof ClosureExpression) { // GROOVY-10271, GROOVY-10272 + return isSAMType(targetType) ? targetType : type; + } + + ClassNode sourceType = Optional.ofNullable(getInferredReturnType(expr)).orElse(type); + if (expr instanceof ConstructorCallExpression) { // GROOVY-9972, GROOVY-9983 // GROOVY-10114: type parameter(s) could be inferred from call arguments if (targetType == null) targetType = sourceType.getPlainNodeReference(); diff --git a/src/test/groovy/transform/stc/TernaryOperatorSTCTest.groovy b/src/test/groovy/transform/stc/TernaryOperatorSTCTest.groovy index c8708bccea3..1c5a7da7b77 100644 --- a/src/test/groovy/transform/stc/TernaryOperatorSTCTest.groovy +++ b/src/test/groovy/transform/stc/TernaryOperatorSTCTest.groovy @@ -166,6 +166,37 @@ class TernaryOperatorSTCTest extends StaticTypeCheckingTestCase { ''' } + // GROOVY-10271 + void testFunctionalInterfaceTarget1() { + ['true', 'false'].each { flag -> + assertScript """import java.util.function.Supplier + + Supplier x = { -> 1 } + Supplier y = $flag ? x : { -> 2 } + + assert y.get() == ($flag ? 1 : 2) + """ + } + } + + // GROOVY-10272 + void testFunctionalInterfaceTarget2() { + assertScript ''' + import java.util.function.Function + + Function x + if (true) { + x = { a -> a.longValue() } + } else { + x = { Integer b -> (Long)b } + } + assert x.apply(42) == 42L + + Function y = (true ? { a -> a.longValue() } : { Integer b -> (Long)b }) + assert y.apply(42) == 42L + ''' + } + // GROOVY-10357 void testAbstractMethodDefault() { assertScript ''' @@ -186,7 +217,7 @@ class TernaryOperatorSTCTest extends StaticTypeCheckingTestCase { } // GROOVY-10358 - void testCommonInterface() { + void testCommonInterface1() { assertScript ''' interface I { int m(int i)