diff --git a/src/main/java/graphql/annotations/GraphQLAnnotations.java b/src/main/java/graphql/annotations/GraphQLAnnotations.java index bf697b6e..ca414cc7 100644 --- a/src/main/java/graphql/annotations/GraphQLAnnotations.java +++ b/src/main/java/graphql/annotations/GraphQLAnnotations.java @@ -188,7 +188,6 @@ public GraphQLInterfaceType.Builder getIfaceBuilder(Class iface) throws Graph return builder; } - public static GraphQLInterfaceType.Builder ifaceBuilder(Class iface) throws GraphQLAnnotationsException, IllegalAccessException { return getInstance().getIfaceBuilder(iface); @@ -410,22 +409,7 @@ protected GraphQLFieldDefinition getField(Field field) throws GraphQLAnnotations GraphQLDataFetcher dataFetcher = field.getAnnotation(GraphQLDataFetcher.class); DataFetcher actualDataFetcher = null; if (nonNull(dataFetcher)) { - final String[] args; - if (dataFetcher.firstArgIsTargetName()) { - args = Stream.concat(Stream.of(field.getName()), stream(dataFetcher.args())).toArray(String[]::new); - } else { - args = dataFetcher.args(); - } - if (args.length == 0) { - actualDataFetcher = newInstance(dataFetcher.value()); - } else { - try { - final Constructor ctr = dataFetcher.value().getDeclaredConstructor( - stream(args).map(v -> String.class).toArray(Class[]::new)); - actualDataFetcher = constructNewInstance(ctr, (Object[]) args); - } catch (final NoSuchMethodException e) { - } - } + actualDataFetcher = constructDataFetcher(field.getName(), dataFetcher); } if (actualDataFetcher == null) { @@ -469,6 +453,26 @@ protected GraphQLFieldDefinition getField(Field field) throws GraphQLAnnotations return new GraphQLFieldDefinitionWrapper(builder.build()); } + private DataFetcher constructDataFetcher(String fieldName, GraphQLDataFetcher annotatedDataFetcher) { + final String[] args; + if ( annotatedDataFetcher.firstArgIsTargetName() ) { + args = Stream.concat(Stream.of(fieldName), stream(annotatedDataFetcher.args())).toArray(String[]::new); + } else { + args = annotatedDataFetcher.args(); + } + if (args.length == 0) { + return newInstance(annotatedDataFetcher.value()); + } else { + try { + final Constructor ctr = annotatedDataFetcher.value().getDeclaredConstructor( + stream(args).map(v -> String.class).toArray(Class[]::new)); + return constructNewInstance(ctr, (Object[]) args); + } catch (final NoSuchMethodException e) { + throw new GraphQLAnnotationsException("Unable to instantiate DataFetcher via constructor for: " + fieldName, e); + } + } + } + protected GraphQLFieldDefinition field(Field field) throws IllegalAccessException, InstantiationException { return getInstance().getField(field); } @@ -592,7 +596,7 @@ protected GraphQLFieldDefinition getField(Method method) throws GraphQLAnnotatio } else if (dataFetcher == null) { actualDataFetcher = new MethodDataFetcher(method, typeFunction); } else { - actualDataFetcher = newInstance(dataFetcher.value()); + actualDataFetcher = constructDataFetcher(method.getName(), dataFetcher); } if (method.isAnnotationPresent(GraphQLRelayMutation.class) && relay != null) { diff --git a/src/test/java/graphql/annotations/GraphQLDataFetcherTest.java b/src/test/java/graphql/annotations/GraphQLDataFetcherTest.java index 65a9c1f8..15273a7e 100644 --- a/src/test/java/graphql/annotations/GraphQLDataFetcherTest.java +++ b/src/test/java/graphql/annotations/GraphQLDataFetcherTest.java @@ -26,6 +26,7 @@ import java.util.HashMap; import static graphql.schema.GraphQLSchema.newSchema; +import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertNotNull; import static org.testng.Assert.assertTrue; @@ -48,15 +49,58 @@ public void shouldUsePreferredConstructor() { assertTrue(((HashMap) data.get("sample")).get("isBad")); } + @Test + public void shouldUseProvidedSoloArgumentForDataFetcherDeclaredInMethod() { + // Given + final GraphQLObjectType object = GraphQLAnnotations.object(TestMethodWithDataFetcherGraphQLQuery.class); + final GraphQLSchema schema = newSchema().query(object).build(); + final GraphQL graphql = GraphQL.newGraphQL(schema).build(); + + // When + final ExecutionResult result = graphql.execute("{great}"); + + // Then + final HashMap data = (HashMap) result.getData(); + assertNotNull(data); + assertFalse((Boolean)data.get("great")); + } + + @Test + public void shouldUseTargetAndArgumentsForDataFetcherDeclaredInMethod() { + // Given + final GraphQLObjectType object = GraphQLAnnotations.object(TestMethodWithDataFetcherGraphQLQuery.class); + final GraphQLSchema schema = newSchema().query(object).build(); + final GraphQL graphql = GraphQL.newGraphQL(schema).build(); + + // When + final ExecutionResult result = graphql.execute("{sample {bad}}"); + + // Then + final HashMap data = (HashMap) result.getData(); + assertNotNull(data); + assertTrue(((HashMap)data.get("sample")).get("bad")); + } + @GraphQLName("Query") public static class TestGraphQLQuery { @GraphQLField @GraphQLDataFetcher(SampleDataFetcher.class) public TestSample sample() { // Note that GraphQL uses TestSample to build the graph - return null; + return null; } } + @GraphQLName("Query") + public static class TestMethodWithDataFetcherGraphQLQuery { + @GraphQLField + @GraphQLDataFetcher(value = SampleOneArgDataFetcher.class, args = "true") + public Boolean great() { return false; } + + @GraphQLField + @GraphQLDataFetcher(SampleDataFetcher.class) + public TestSampleMethod sample() { return null; } + } + public static class TestSample { @GraphQLField @GraphQLDataFetcher(value = PropertyDataFetcher.class, args = "isGreat") @@ -68,6 +112,14 @@ public static class TestSample { } + public static class TestSampleMethod { + + @GraphQLField + @GraphQLDataFetcher(value = SampleMultiArgDataFetcher.class, firstArgIsTargetName = true, args = {"true"}) + public Boolean isBad() { return false; } // Defaults to FieldDataFetcher + + } + public static class SampleDataFetcher implements DataFetcher { @Override public Object get(final DataFetchingEnvironment environment) { @@ -75,6 +127,23 @@ public Object get(final DataFetchingEnvironment environment) { } } + public static class SampleOneArgDataFetcher implements DataFetcher { + private boolean flip = false; + + public SampleOneArgDataFetcher(String flip) { + this.flip = Boolean.valueOf(flip); + } + + @Override + public Object get(DataFetchingEnvironment environment) { + if ( flip ) { + return !flip; + } else { + return flip; + } + } + } + public static class SampleMultiArgDataFetcher extends PropertyDataFetcher { private boolean flip = false; @@ -87,7 +156,7 @@ public SampleMultiArgDataFetcher(String target, String flip) { public Object get(DataFetchingEnvironment environment) { final Object result = super.get(environment); if (flip) { - return !(Boolean) result; + return !(Boolean)result; } else { return result; }