From 5e03289d257ed2870225706797f3522df1b6662a Mon Sep 17 00:00:00 2001 From: Eric Milles Date: Fri, 8 Mar 2024 17:37:53 -0600 Subject: [PATCH] GROOVY-11335: STC: loop item type from `UnionTypeClassNode` --- .../stc/StaticTypeCheckingVisitor.java | 61 ++++++------ .../transform/stc/UnionTypeClassNode.java | 98 +++++++++---------- .../groovy/transform/stc/LoopsSTCTest.groovy | 14 +++ 3 files changed, 90 insertions(+), 83 deletions(-) diff --git a/src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java b/src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java index 2be11c7c1f0..9a8066f8e4d 100644 --- a/src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java +++ b/src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java @@ -2067,33 +2067,34 @@ public void visitForLoop(final ForStatement forLoop) { * @see #inferComponentType */ public static ClassNode inferLoopElementType(final ClassNode collectionType) { - ClassNode componentType = collectionType.getComponentType(); - if (componentType == null) { - if (isOrImplements(collectionType, ITERABLE_TYPE)) { - ClassNode col = GenericsUtils.parameterizeType(collectionType, ITERABLE_TYPE); - componentType = getCombinedBoundType(col.getGenericsTypes()[0]); - - } else if (isOrImplements(collectionType, MAP_TYPE)) { // GROOVY-6240 - ClassNode col = GenericsUtils.parameterizeType(collectionType, MAP_TYPE); - componentType = makeClassSafe0(MAP_ENTRY_TYPE, col.getGenericsTypes()); - - } else if (isOrImplements(collectionType, STREAM_TYPE)) { // GROOVY-10476 - ClassNode col = GenericsUtils.parameterizeType(collectionType, STREAM_TYPE); - componentType = getCombinedBoundType(col.getGenericsTypes()[0]); - - } else if (isOrImplements(collectionType, ENUMERATION_TYPE)) { // GROOVY-6123 - ClassNode col = GenericsUtils.parameterizeType(collectionType, ENUMERATION_TYPE); - componentType = getCombinedBoundType(col.getGenericsTypes()[0]); - - } else if (isOrImplements(collectionType, Iterator_TYPE)) { // GROOVY-10712 - ClassNode col = GenericsUtils.parameterizeType(collectionType, Iterator_TYPE); - componentType = getCombinedBoundType(col.getGenericsTypes()[0]); - - } else if (isStringType(collectionType)) { - componentType = STRING_TYPE; - } else { - componentType = OBJECT_TYPE; - } + ClassNode componentType; + if (collectionType.isArray()) { // GROOVY-11335 + componentType = collectionType.getComponentType(); + + } else if (isOrImplements(collectionType, ITERABLE_TYPE)) { + ClassNode col = GenericsUtils.parameterizeType(collectionType, ITERABLE_TYPE); + componentType = getCombinedBoundType(col.getGenericsTypes()[0]); + + } else if (isOrImplements(collectionType, MAP_TYPE)) { // GROOVY-6240 + ClassNode col = GenericsUtils.parameterizeType(collectionType, MAP_TYPE); + componentType = makeClassSafe0(MAP_ENTRY_TYPE, col.getGenericsTypes()); + + } else if (isOrImplements(collectionType, STREAM_TYPE)) { // GROOVY-10476 + ClassNode col = GenericsUtils.parameterizeType(collectionType, STREAM_TYPE); + componentType = getCombinedBoundType(col.getGenericsTypes()[0]); + + } else if (isOrImplements(collectionType, Iterator_TYPE)) { // GROOVY-10712 + ClassNode col = GenericsUtils.parameterizeType(collectionType, Iterator_TYPE); + componentType = getCombinedBoundType(col.getGenericsTypes()[0]); + + } else if (isOrImplements(collectionType, ENUMERATION_TYPE)) { // GROOVY-6123 + ClassNode col = GenericsUtils.parameterizeType(collectionType, ENUMERATION_TYPE); + componentType = getCombinedBoundType(col.getGenericsTypes()[0]); + + } else if (isStringType(collectionType)) { + componentType = STRING_TYPE; + } else { + componentType = OBJECT_TYPE; } return componentType; } @@ -4678,8 +4679,10 @@ protected static ClassNode getGroupOperationResultType(final ClassNode a, final } protected ClassNode inferComponentType(final ClassNode receiverType, final ClassNode subscriptType) { - ClassNode componentType = receiverType.getComponentType(); - if (componentType == null) { + ClassNode componentType = null; + if (receiverType.isArray()) { // GROOVY-11335 + componentType = receiverType.getComponentType(); + } else { MethodCallExpression mce; if (subscriptType != null) { // GROOVY-5521: check for a suitable "getAt(T)" method mce = callX(varX("#", receiverType), "getAt", varX("selector", subscriptType)); diff --git a/src/main/java/org/codehaus/groovy/transform/stc/UnionTypeClassNode.java b/src/main/java/org/codehaus/groovy/transform/stc/UnionTypeClassNode.java index 90b65606878..cb1377b9b51 100644 --- a/src/main/java/org/codehaus/groovy/transform/stc/UnionTypeClassNode.java +++ b/src/main/java/org/codehaus/groovy/transform/stc/UnionTypeClassNode.java @@ -35,7 +35,6 @@ import org.codehaus.groovy.transform.ASTTransformation; import java.util.Arrays; -import java.util.Collections; import java.util.HashSet; import java.util.Iterator; import java.util.LinkedHashSet; @@ -172,59 +171,51 @@ public void addTransform(final Class transform, fin throw new UnsupportedOperationException(); } - @Override - public boolean declaresInterface(final ClassNode classNode) { - for (ClassNode delegate : delegates) { - if (delegate.declaresInterface(classNode)) return true; - } - return false; - } - @Override public List getAbstractMethods() { - List allMethods = new LinkedList(); + List answer = new LinkedList<>(); for (ClassNode delegate : delegates) { - allMethods.addAll(delegate.getAbstractMethods()); + answer.addAll(delegate.getAbstractMethods()); } - return allMethods; + return answer; } @Override public List getAllDeclaredMethods() { - List allMethods = new LinkedList(); + List answer = new LinkedList<>(); for (ClassNode delegate : delegates) { - allMethods.addAll(delegate.getAllDeclaredMethods()); + answer.addAll(delegate.getAllDeclaredMethods()); } - return allMethods; + return answer; } @Override public Set getAllInterfaces() { - Set allMethods = new HashSet(); + Set answer = new HashSet<>(); for (ClassNode delegate : delegates) { - allMethods.addAll(delegate.getAllInterfaces()); + answer.addAll(delegate.getAllInterfaces()); } - return allMethods; + return answer; } @Override public List getAnnotations() { - List nodes = new LinkedList(); + List answer = new LinkedList<>(); for (ClassNode delegate : delegates) { List annotations = delegate.getAnnotations(); - if (annotations != null) nodes.addAll(annotations); + if (annotations != null) answer.addAll(annotations); } - return nodes; + return answer; } @Override public List getAnnotations(final ClassNode type) { - List nodes = new LinkedList(); + List answer = new LinkedList<>(); for (ClassNode delegate : delegates) { List annotations = delegate.getAnnotations(type); - if (annotations != null) nodes.addAll(annotations); + if (annotations != null) answer.addAll(annotations); } - return nodes; + return answer; } @Override @@ -234,11 +225,11 @@ public ClassNode getComponentType() { @Override public List getDeclaredConstructors() { - List nodes = new LinkedList(); + List answer = new LinkedList<>(); for (ClassNode delegate : delegates) { - nodes.addAll(delegate.getDeclaredConstructors()); + answer.addAll(delegate.getDeclaredConstructors()); } - return nodes; + return answer; } @Override @@ -261,12 +252,12 @@ public MethodNode getDeclaredMethod(final String name, final Parameter[] paramet @Override public List getDeclaredMethods(final String name) { - List nodes = new LinkedList(); + List answer = new LinkedList<>(); for (ClassNode delegate : delegates) { List methods = delegate.getDeclaredMethods(name); - if (methods != null) nodes.addAll(methods); + if (methods != null) answer.addAll(methods); } - return nodes; + return answer; } @Override @@ -290,12 +281,12 @@ public FieldNode getField(final String name) { @Override public List getFields() { - List nodes = new LinkedList(); + List answer = new LinkedList<>(); for (ClassNode delegate : delegates) { List fields = delegate.getFields(); - if (fields != null) nodes.addAll(fields); + if (fields != null) answer.addAll(fields); } - return nodes; + return answer; } @Override @@ -305,22 +296,25 @@ public Iterator getInnerClasses() { @Override public ClassNode[] getInterfaces() { - Set nodes = new LinkedHashSet(); + Set answer = new LinkedHashSet<>(); for (ClassNode delegate : delegates) { - ClassNode[] interfaces = delegate.getInterfaces(); - if (interfaces != null) Collections.addAll(nodes, interfaces); + if (delegate.isInterface()) { + answer.remove(delegate); answer.add(delegate); + } else { + answer.addAll(Arrays.asList(delegate.getInterfaces())); + } } - return nodes.toArray(ClassNode.EMPTY_ARRAY); + return answer.toArray(ClassNode.EMPTY_ARRAY); } @Override public List getMethods() { - List nodes = new LinkedList(); + List answer = new LinkedList<>(); for (ClassNode delegate : delegates) { List methods = delegate.getMethods(); - if (methods != null) nodes.addAll(methods); + if (methods != null) answer.addAll(methods); } - return nodes; + return answer; } @Override @@ -334,12 +328,12 @@ public ClassNode getPlainNodeReference(final boolean skipPrimitives) { @Override public List getProperties() { - List nodes = new LinkedList(); + List answer = new LinkedList<>(); for (ClassNode delegate : delegates) { List properties = delegate.getProperties(); - if (properties != null) nodes.addAll(properties); + if (properties != null) answer.addAll(properties); } - return nodes; + return answer; } @Override @@ -349,22 +343,18 @@ public Class getTypeClass() { @Override public ClassNode[] getUnresolvedInterfaces() { - Set nodes = new LinkedHashSet(); - for (ClassNode delegate : delegates) { - ClassNode[] interfaces = delegate.getUnresolvedInterfaces(); - if (interfaces != null) Collections.addAll(nodes, interfaces); - } - return nodes.toArray(ClassNode.EMPTY_ARRAY); + return getUnresolvedInterfaces(false); } @Override public ClassNode[] getUnresolvedInterfaces(final boolean useRedirect) { - Set nodes = new LinkedHashSet(); - for (ClassNode delegate : delegates) { - ClassNode[] interfaces = delegate.getUnresolvedInterfaces(useRedirect); - if (interfaces != null) Collections.addAll(nodes, interfaces); + ClassNode[] interfaces = getInterfaces(); + if (useRedirect) { + for (int i = 0; i < interfaces.length; ++i) { + interfaces[i] = interfaces[i].redirect(); + } } - return nodes.toArray(ClassNode.EMPTY_ARRAY); + return interfaces; } @Override diff --git a/src/test/groovy/transform/stc/LoopsSTCTest.groovy b/src/test/groovy/transform/stc/LoopsSTCTest.groovy index 11bc0a91d3e..e0562a61400 100644 --- a/src/test/groovy/transform/stc/LoopsSTCTest.groovy +++ b/src/test/groovy/transform/stc/LoopsSTCTest.groovy @@ -252,6 +252,20 @@ class LoopsSTCTest extends StaticTypeCheckingTestCase { ''' } + // GROOVY-11335 + void testForInLoopOnCollection() { + assertScript ''' + def whatever(Collection coll) { + if (coll instanceof Serializable) { + for (item in coll) { + return item.toLowerCase() + } + } + } + assert whatever(['Works']) == 'works' + ''' + } + // GROOVY-6123 void testForInLoopOnEnumeration() { assertScript '''