Skip to content

Commit

Permalink
GROOVY-11335: STC: loop item type from UnionTypeClassNode
Browse files Browse the repository at this point in the history
  • Loading branch information
eric-milles committed Mar 9, 2024
1 parent 84e1299 commit 5e03289
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 83 deletions.
Expand Up @@ -2067,33 +2067,34 @@ public void visitForLoop(final ForStatement forLoop) {
* @see #inferComponentType
*/
public static ClassNode inferLoopElementType(final ClassNode collectionType) {
ClassNode componentType = collectionType.getComponentType();
if (componentType == null) {
if (isOrImplements(collectionType, ITERABLE_TYPE)) {
ClassNode col = GenericsUtils.parameterizeType(collectionType, ITERABLE_TYPE);
componentType = getCombinedBoundType(col.getGenericsTypes()[0]);

} else if (isOrImplements(collectionType, MAP_TYPE)) { // GROOVY-6240
ClassNode col = GenericsUtils.parameterizeType(collectionType, MAP_TYPE);
componentType = makeClassSafe0(MAP_ENTRY_TYPE, col.getGenericsTypes());

} else if (isOrImplements(collectionType, STREAM_TYPE)) { // GROOVY-10476
ClassNode col = GenericsUtils.parameterizeType(collectionType, STREAM_TYPE);
componentType = getCombinedBoundType(col.getGenericsTypes()[0]);

} else if (isOrImplements(collectionType, ENUMERATION_TYPE)) { // GROOVY-6123
ClassNode col = GenericsUtils.parameterizeType(collectionType, ENUMERATION_TYPE);
componentType = getCombinedBoundType(col.getGenericsTypes()[0]);

} else if (isOrImplements(collectionType, Iterator_TYPE)) { // GROOVY-10712
ClassNode col = GenericsUtils.parameterizeType(collectionType, Iterator_TYPE);
componentType = getCombinedBoundType(col.getGenericsTypes()[0]);

} else if (isStringType(collectionType)) {
componentType = STRING_TYPE;
} else {
componentType = OBJECT_TYPE;
}
ClassNode componentType;
if (collectionType.isArray()) { // GROOVY-11335
componentType = collectionType.getComponentType();

} else if (isOrImplements(collectionType, ITERABLE_TYPE)) {
ClassNode col = GenericsUtils.parameterizeType(collectionType, ITERABLE_TYPE);
componentType = getCombinedBoundType(col.getGenericsTypes()[0]);

} else if (isOrImplements(collectionType, MAP_TYPE)) { // GROOVY-6240
ClassNode col = GenericsUtils.parameterizeType(collectionType, MAP_TYPE);
componentType = makeClassSafe0(MAP_ENTRY_TYPE, col.getGenericsTypes());

} else if (isOrImplements(collectionType, STREAM_TYPE)) { // GROOVY-10476
ClassNode col = GenericsUtils.parameterizeType(collectionType, STREAM_TYPE);
componentType = getCombinedBoundType(col.getGenericsTypes()[0]);

} else if (isOrImplements(collectionType, Iterator_TYPE)) { // GROOVY-10712
ClassNode col = GenericsUtils.parameterizeType(collectionType, Iterator_TYPE);
componentType = getCombinedBoundType(col.getGenericsTypes()[0]);

} else if (isOrImplements(collectionType, ENUMERATION_TYPE)) { // GROOVY-6123
ClassNode col = GenericsUtils.parameterizeType(collectionType, ENUMERATION_TYPE);
componentType = getCombinedBoundType(col.getGenericsTypes()[0]);

} else if (isStringType(collectionType)) {
componentType = STRING_TYPE;
} else {
componentType = OBJECT_TYPE;
}
return componentType;
}
Expand Down Expand Up @@ -4678,8 +4679,10 @@ protected static ClassNode getGroupOperationResultType(final ClassNode a, final
}

protected ClassNode inferComponentType(final ClassNode receiverType, final ClassNode subscriptType) {
ClassNode componentType = receiverType.getComponentType();
if (componentType == null) {
ClassNode componentType = null;
if (receiverType.isArray()) { // GROOVY-11335
componentType = receiverType.getComponentType();
} else {
MethodCallExpression mce;
if (subscriptType != null) { // GROOVY-5521: check for a suitable "getAt(T)" method
mce = callX(varX("#", receiverType), "getAt", varX("selector", subscriptType));
Expand Down
Expand Up @@ -35,7 +35,6 @@
import org.codehaus.groovy.transform.ASTTransformation;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
Expand Down Expand Up @@ -172,59 +171,51 @@ public void addTransform(final Class<? extends ASTTransformation> transform, fin
throw new UnsupportedOperationException();
}

@Override
public boolean declaresInterface(final ClassNode classNode) {
for (ClassNode delegate : delegates) {
if (delegate.declaresInterface(classNode)) return true;
}
return false;
}

@Override
public List<MethodNode> getAbstractMethods() {
List<MethodNode> allMethods = new LinkedList<MethodNode>();
List<MethodNode> answer = new LinkedList<>();
for (ClassNode delegate : delegates) {
allMethods.addAll(delegate.getAbstractMethods());
answer.addAll(delegate.getAbstractMethods());
}
return allMethods;
return answer;
}

@Override
public List<MethodNode> getAllDeclaredMethods() {
List<MethodNode> allMethods = new LinkedList<MethodNode>();
List<MethodNode> answer = new LinkedList<>();
for (ClassNode delegate : delegates) {
allMethods.addAll(delegate.getAllDeclaredMethods());
answer.addAll(delegate.getAllDeclaredMethods());
}
return allMethods;
return answer;
}

@Override
public Set<ClassNode> getAllInterfaces() {
Set<ClassNode> allMethods = new HashSet<ClassNode>();
Set<ClassNode> answer = new HashSet<>();
for (ClassNode delegate : delegates) {
allMethods.addAll(delegate.getAllInterfaces());
answer.addAll(delegate.getAllInterfaces());
}
return allMethods;
return answer;
}

@Override
public List<AnnotationNode> getAnnotations() {
List<AnnotationNode> nodes = new LinkedList<AnnotationNode>();
List<AnnotationNode> answer = new LinkedList<>();
for (ClassNode delegate : delegates) {
List<AnnotationNode> annotations = delegate.getAnnotations();
if (annotations != null) nodes.addAll(annotations);
if (annotations != null) answer.addAll(annotations);
}
return nodes;
return answer;
}

@Override
public List<AnnotationNode> getAnnotations(final ClassNode type) {
List<AnnotationNode> nodes = new LinkedList<AnnotationNode>();
List<AnnotationNode> answer = new LinkedList<>();
for (ClassNode delegate : delegates) {
List<AnnotationNode> annotations = delegate.getAnnotations(type);
if (annotations != null) nodes.addAll(annotations);
if (annotations != null) answer.addAll(annotations);
}
return nodes;
return answer;
}

@Override
Expand All @@ -234,11 +225,11 @@ public ClassNode getComponentType() {

@Override
public List<ConstructorNode> getDeclaredConstructors() {
List<ConstructorNode> nodes = new LinkedList<ConstructorNode>();
List<ConstructorNode> answer = new LinkedList<>();
for (ClassNode delegate : delegates) {
nodes.addAll(delegate.getDeclaredConstructors());
answer.addAll(delegate.getDeclaredConstructors());
}
return nodes;
return answer;
}

@Override
Expand All @@ -261,12 +252,12 @@ public MethodNode getDeclaredMethod(final String name, final Parameter[] paramet

@Override
public List<MethodNode> getDeclaredMethods(final String name) {
List<MethodNode> nodes = new LinkedList<MethodNode>();
List<MethodNode> answer = new LinkedList<>();
for (ClassNode delegate : delegates) {
List<MethodNode> methods = delegate.getDeclaredMethods(name);
if (methods != null) nodes.addAll(methods);
if (methods != null) answer.addAll(methods);
}
return nodes;
return answer;
}

@Override
Expand All @@ -290,12 +281,12 @@ public FieldNode getField(final String name) {

@Override
public List<FieldNode> getFields() {
List<FieldNode> nodes = new LinkedList<FieldNode>();
List<FieldNode> answer = new LinkedList<>();
for (ClassNode delegate : delegates) {
List<FieldNode> fields = delegate.getFields();
if (fields != null) nodes.addAll(fields);
if (fields != null) answer.addAll(fields);
}
return nodes;
return answer;
}

@Override
Expand All @@ -305,22 +296,25 @@ public Iterator<InnerClassNode> getInnerClasses() {

@Override
public ClassNode[] getInterfaces() {
Set<ClassNode> nodes = new LinkedHashSet<ClassNode>();
Set<ClassNode> answer = new LinkedHashSet<>();
for (ClassNode delegate : delegates) {
ClassNode[] interfaces = delegate.getInterfaces();
if (interfaces != null) Collections.addAll(nodes, interfaces);
if (delegate.isInterface()) {
answer.remove(delegate); answer.add(delegate);
} else {
answer.addAll(Arrays.asList(delegate.getInterfaces()));
}
}
return nodes.toArray(ClassNode.EMPTY_ARRAY);
return answer.toArray(ClassNode.EMPTY_ARRAY);
}

@Override
public List<MethodNode> getMethods() {
List<MethodNode> nodes = new LinkedList<MethodNode>();
List<MethodNode> answer = new LinkedList<>();
for (ClassNode delegate : delegates) {
List<MethodNode> methods = delegate.getMethods();
if (methods != null) nodes.addAll(methods);
if (methods != null) answer.addAll(methods);
}
return nodes;
return answer;
}

@Override
Expand All @@ -334,12 +328,12 @@ public ClassNode getPlainNodeReference(final boolean skipPrimitives) {

@Override
public List<PropertyNode> getProperties() {
List<PropertyNode> nodes = new LinkedList<PropertyNode>();
List<PropertyNode> answer = new LinkedList<>();
for (ClassNode delegate : delegates) {
List<PropertyNode> properties = delegate.getProperties();
if (properties != null) nodes.addAll(properties);
if (properties != null) answer.addAll(properties);
}
return nodes;
return answer;
}

@Override
Expand All @@ -349,22 +343,18 @@ public Class getTypeClass() {

@Override
public ClassNode[] getUnresolvedInterfaces() {
Set<ClassNode> nodes = new LinkedHashSet<ClassNode>();
for (ClassNode delegate : delegates) {
ClassNode[] interfaces = delegate.getUnresolvedInterfaces();
if (interfaces != null) Collections.addAll(nodes, interfaces);
}
return nodes.toArray(ClassNode.EMPTY_ARRAY);
return getUnresolvedInterfaces(false);
}

@Override
public ClassNode[] getUnresolvedInterfaces(final boolean useRedirect) {
Set<ClassNode> nodes = new LinkedHashSet<ClassNode>();
for (ClassNode delegate : delegates) {
ClassNode[] interfaces = delegate.getUnresolvedInterfaces(useRedirect);
if (interfaces != null) Collections.addAll(nodes, interfaces);
ClassNode[] interfaces = getInterfaces();
if (useRedirect) {
for (int i = 0; i < interfaces.length; ++i) {
interfaces[i] = interfaces[i].redirect();
}
}
return nodes.toArray(ClassNode.EMPTY_ARRAY);
return interfaces;
}

@Override
Expand Down
14 changes: 14 additions & 0 deletions src/test/groovy/transform/stc/LoopsSTCTest.groovy
Expand Up @@ -252,6 +252,20 @@ class LoopsSTCTest extends StaticTypeCheckingTestCase {
'''
}

// GROOVY-11335
void testForInLoopOnCollection() {
assertScript '''
def whatever(Collection<String> coll) {
if (coll instanceof Serializable) {
for (item in coll) {
return item.toLowerCase()
}
}
}
assert whatever(['Works']) == 'works'
'''
}

// GROOVY-6123
void testForInLoopOnEnumeration() {
assertScript '''
Expand Down

0 comments on commit 5e03289

Please sign in to comment.