diff --git a/src/main/java/org/junit/experimental/theories/internal/AllMembersSupplier.java b/src/main/java/org/junit/experimental/theories/internal/AllMembersSupplier.java index a8d7c9a6bd9d..2e1a89f4b2cf 100644 --- a/src/main/java/org/junit/experimental/theories/internal/AllMembersSupplier.java +++ b/src/main/java/org/junit/experimental/theories/internal/AllMembersSupplier.java @@ -4,6 +4,7 @@ import java.lang.reflect.Field; import java.util.ArrayList; import java.util.Collection; +import java.util.Iterator; import java.util.List; import org.junit.Assume; @@ -76,9 +77,11 @@ private void addMultiPointMethods(ParameterSignature sig, List returnType = dataPointsMethod.getReturnType(); - if (returnType.isArray() && sig.canPotentiallyAcceptType(returnType.getComponentType())) { + if ((returnType.isArray() && sig.canPotentiallyAcceptType(returnType.getComponentType())) || + Iterable.class.isAssignableFrom(returnType)) { try { - addArrayValues(sig, dataPointsMethod.getName(), list, dataPointsMethod.invokeExplosively(null)); + addDataPointsValues(returnType, sig, dataPointsMethod.getName(), list, + dataPointsMethod.invokeExplosively(null)); } catch (Throwable throwable) { DataPoints annotation = dataPointsMethod.getAnnotation(DataPoints.class); if (annotation != null && isAssignableToAnyOf(annotation.ignoredExceptions(), throwable)) { @@ -101,9 +104,10 @@ private void addSinglePointMethods(ParameterSignature sig, List list) { for (final Field field : getDataPointsFields(sig)) { - addArrayValues(sig, field.getName(), list, getStaticFieldValue(field)); + Class type = field.getType(); + addDataPointsValues(type, sig, field.getName(), list, getStaticFieldValue(field)); } - } + } private void addSinglePointFields(ParameterSignature sig, List list) { for (final Field field : getSingleDataPointFields(sig)) { @@ -114,6 +118,16 @@ private void addSinglePointFields(ParameterSignature sig, List type, ParameterSignature sig, String name, + List list, Object value) { + if (type.isArray()) { + addArrayValues(sig, name, list, value); + } + else if (Iterable.class.isAssignableFrom(type)) { + addIterableValues(sig, name, list, (Iterable) value); + } + } private void addArrayValues(ParameterSignature sig, String name, List list, Object array) { for (int i = 0; i < Array.getLength(array); i++) { @@ -123,6 +137,18 @@ private void addArrayValues(ParameterSignature sig, String name, List list, Iterable iterable) { + Iterator iterator = iterable.iterator(); + int i = 0; + while (iterator.hasNext()) { + Object value = iterator.next(); + if (sig.canAcceptValue(value)) { + list.add(PotentialAssignment.forValue(name + "[" + i + "]", value)); + } + i += 1; + } + } private Object getStaticFieldValue(final Field field) { try { diff --git a/src/test/java/org/junit/tests/experimental/theories/internal/AllMembersSupplierTest.java b/src/test/java/org/junit/tests/experimental/theories/internal/AllMembersSupplierTest.java index eac9f2d9ed0d..a3ee3635bd48 100644 --- a/src/test/java/org/junit/tests/experimental/theories/internal/AllMembersSupplierTest.java +++ b/src/test/java/org/junit/tests/experimental/theories/internal/AllMembersSupplierTest.java @@ -6,6 +6,7 @@ import static org.junit.Assert.assertThat; import static org.junit.tests.experimental.theories.TheoryTestUtils.potentialAssignments; +import java.util.Arrays; import java.util.List; import org.junit.Rule; @@ -23,7 +24,6 @@ public class AllMembersSupplierTest { @Rule public ExpectedException expected = ExpectedException.none(); - public static class HasDataPointsArrayField { @DataPoints public static String[] list = new String[] { "qwe", "asd" }; @@ -153,4 +153,57 @@ private List allMemberValuesFor(Class testClass, testClass.getConstructor(constructorParameterTypes)) .get(0)); } -} + + public static class HasDataPointsListField { + @DataPoints + public static List list = Arrays.asList("one", "two"); + + @Theory + public void theory(String param) { + } + } + + @Test + public void dataPointsCollectionFieldsShouldBeRecognized() throws Throwable { + List assignments = potentialAssignments( + HasDataPointsListField.class.getMethod("theory", String.class)); + + assertEquals(2, assignments.size()); + } + + public static class HasDataPointsListMethod { + @DataPoints + public static List getList() { + return Arrays.asList("one", "two"); + } + + @Theory + public void theory(String param) { + } + } + + @Test + public void dataPointsCollectionMethodShouldBeRecognized() throws Throwable { + List assignments = potentialAssignments( + HasDataPointsListMethod.class.getMethod("theory", String.class)); + + assertEquals(2, assignments.size()); + } + + public static class HasDataPointsListFieldWithOverlyGenericTypes { + @DataPoints + public static List list = Arrays.asList("string", new Object()); + + @Theory + public void theory(String param) { + } + } + + @Test + public void dataPointsCollectionShouldBeRecognizedIgnoringStrangeTypes() throws Throwable { + List assignments = potentialAssignments( + HasDataPointsListFieldWithOverlyGenericTypes.class.getMethod("theory", String.class)); + + assertEquals(1, assignments.size()); + } +} \ No newline at end of file