Skip to content

Commit

Permalink
Merge pull request #536 from hantsy/master
Browse files Browse the repository at this point in the history
add @DsgDirective
  • Loading branch information
paulbakker committed Aug 19, 2021
2 parents 916277e + a653449 commit 917bb29
Show file tree
Hide file tree
Showing 19 changed files with 333 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import com.jayway.jsonpath.TypeRef
import com.jayway.jsonpath.spi.mapper.MappingException
import com.netflix.graphql.dgs.DgsComponent
import com.netflix.graphql.dgs.DgsData
import com.netflix.graphql.dgs.DgsDirective
import com.netflix.graphql.dgs.DgsScalar
import com.netflix.graphql.dgs.exceptions.DgsQueryExecutionDataExtractionException
import com.netflix.graphql.dgs.exceptions.QueryException
Expand Down Expand Up @@ -100,6 +101,7 @@ internal class DefaultDgsReactiveQueryExecutorTest {
LocalDateTimeScalar()
)
)
every { applicationContextMock.getBeansWithAnnotation(DgsDirective::class.java) } returns emptyMap()
every { dgsDataLoaderProvider.buildRegistryWithContextSupplier(any<Supplier<Any>>()) } returns DataLoaderRegistry()

val provider = DgsSchemaProvider(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package com.netflix.graphql.dgs.reactive

import com.netflix.graphql.dgs.DgsComponent
import com.netflix.graphql.dgs.DgsData
import com.netflix.graphql.dgs.DgsDirective
import com.netflix.graphql.dgs.DgsScalar
import com.netflix.graphql.dgs.exceptions.QueryException
import com.netflix.graphql.dgs.internal.DgsDataLoaderProvider
Expand Down Expand Up @@ -102,6 +103,7 @@ internal class ReactiveReturnTypesTest {
LocalDateTimeScalar()
)
)
every { applicationContextMock.getBeansWithAnnotation(DgsDirective::class.java) } returns emptyMap()
every { dgsDataLoaderProvider.buildRegistryWithContextSupplier(any<Supplier<Any>>()) } returns DataLoaderRegistry()

val provider = DgsSchemaProvider(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* Copyright 2021 Netflix, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.netflix.graphql.dgs;

import org.springframework.stereotype.Component;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
* Mark a class as a custom Directive implementation that gets registered to the framework.
*/
@Target(ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME)
@Component
public @interface DgsDirective {
String name() default "";
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,7 @@ import graphql.language.InterfaceTypeDefinition
import graphql.language.TypeName
import graphql.language.UnionTypeDefinition
import graphql.schema.*
import graphql.schema.idl.RuntimeWiring
import graphql.schema.idl.SchemaParser
import graphql.schema.idl.TypeDefinitionRegistry
import graphql.schema.idl.TypeRuntimeWiring
import graphql.schema.idl.*
import graphql.schema.visibility.DefaultGraphqlFieldVisibility
import graphql.schema.visibility.GraphqlFieldVisibility
import org.slf4j.LoggerFactory
Expand Down Expand Up @@ -108,6 +105,7 @@ class DgsSchemaProvider(
.mapNotNull { dgsComponent -> invokeDgsTypeDefinitionRegistry(dgsComponent, mergedRegistry) }
.fold(mergedRegistry) { a, b -> a.merge(b) }
findScalars(applicationContext, runtimeWiringBuilder)
findDirectives(applicationContext, runtimeWiringBuilder)
findDataFetchers(dgsComponents, codeRegistryBuilder, mergedRegistry)
findTypeResolvers(dgsComponents, runtimeWiringBuilder, mergedRegistry)
findEntityFetchers(dgsComponents)
Expand Down Expand Up @@ -387,6 +385,21 @@ class DgsSchemaProvider(
}
}

private fun findDirectives(applicationContext: ApplicationContext, runtimeWiringBuilder: RuntimeWiring.Builder) {
applicationContext.getBeansWithAnnotation(DgsDirective::class.java).forEach { (_, directiveComponent) ->
val annotation = directiveComponent::class.java.getAnnotation(DgsDirective::class.java)
when (directiveComponent) {
is SchemaDirectiveWiring ->
if (annotation.name.isNotBlank()) {
runtimeWiringBuilder.directive(annotation.name, directiveComponent)
} else {
runtimeWiringBuilder.directiveWiring(directiveComponent)
}
else -> throw RuntimeException("Invalid @DgsDirective type: the class must implement graphql.schema.idl.SchemaDirectiveWiring")
}
}
}

internal fun findSchemaFiles(hasDynamicTypeRegistry: Boolean = false): List<Resource> {
val cl = Thread.currentThread().contextClassLoader

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ class CoroutineDataFetcherTest {
)
)
every { applicationContextMock.getBeansWithAnnotation(DgsScalar::class.java) } returns emptyMap()
every { applicationContextMock.getBeansWithAnnotation(DgsDirective::class.java) } returns emptyMap()

val provider = DgsSchemaProvider(applicationContextMock, Optional.empty(), Optional.empty(), Optional.empty())
val schema = provider.schema(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
/*
* Copyright 2021 Netflix, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.netflix.graphql.dgs

import com.netflix.graphql.dgs.internal.DgsSchemaProvider
import graphql.GraphQL
import io.mockk.every
import io.mockk.impl.annotations.MockK
import io.mockk.junit5.MockKExtension
import org.assertj.core.api.Assertions.assertThat
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.extension.ExtendWith
import org.springframework.context.ApplicationContext
import java.util.*

@ExtendWith(MockKExtension::class)
class CustomDirectivesTest {
@MockK
lateinit var applicationContextMock: ApplicationContext

@Test
fun testCustomDirectives() {
val fetcher = object : Any() {
@DgsData(parentType = "Query", field = "hello")
fun hello(): String = "hello"

@DgsData(parentType = "Query", field = "word")
fun word(): String = "abcefg"
}

every { applicationContextMock.getBeansWithAnnotation(DgsComponent::class.java) } returns mapOf(
Pair(
"helloFetcher",
fetcher
)
)
every { applicationContextMock.getBeansWithAnnotation(DgsScalar::class.java) } returns mapOf()
every { applicationContextMock.getBeansWithAnnotation(DgsDirective::class.java) } returns mapOf(
Pair(
"uppercase",
UppercaseDirective()
),
Pair(
"wordfilter",
WordFilterDirective()
)
)

val provider = DgsSchemaProvider(applicationContextMock, Optional.empty(), Optional.empty(), Optional.empty())

val schema = provider.schema(
"""
type Query {
hello: String @uppercase
word: String
}
directive @uppercase on FIELD_DEFINITION
""".trimIndent()
)

val build = GraphQL.newGraphQL(schema).build()
val executionResult = build.execute(
"""
{
hello
}
""".trimIndent()
)

assertEquals(0, executionResult.errors.size)
val data = executionResult.getData<Map<String, String>>()
assertThat(data["hello"]).isEqualTo("HELLO")

// test global directive
val wordExecutionResult = build.execute(
"""
{
word
}
""".trimIndent()
)

assertEquals(0, wordExecutionResult.errors.size)
val wordData = wordExecutionResult.getData<Map<String, String>>()
assertThat(wordData["word"]).contains("xxx")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,16 +51,27 @@ class CustomScalarsTest {
}
}

every { applicationContextMock.getBeansWithAnnotation(DgsComponent::class.java) } returns mapOf(Pair("timeFetcher", fetcher))
every { applicationContextMock.getBeansWithAnnotation(DgsScalar::class.java) } returns mapOf(Pair("localDateTimeScalar", LocalDateTimeScalar()))
every { applicationContextMock.getBeansWithAnnotation(DgsComponent::class.java) } returns mapOf(
Pair(
"timeFetcher",
fetcher
)
)
every { applicationContextMock.getBeansWithAnnotation(DgsScalar::class.java) } returns mapOf(
Pair(
"localDateTimeScalar",
LocalDateTimeScalar()
)
)
every { applicationContextMock.getBeansWithAnnotation(DgsDirective::class.java) } returns emptyMap()

val provider = DgsSchemaProvider(applicationContextMock, Optional.empty(), Optional.empty(), Optional.empty())

val schema = provider.schema(
"""
type Query {
now: DateTime
schedule(time: DateTime): Boolean
schedule(time: DateTime): Boolean
}
scalar DateTime
Expand All @@ -78,6 +89,8 @@ class CustomScalarsTest {

Assertions.assertEquals(0, executionResult.errors.size)
val data = executionResult.getData<Map<String, String>>()
Assertions.assertTrue(LocalDateTime.parse(data["now"], DateTimeFormatter.ISO_DATE_TIME).plusHours(1).isAfter(LocalDateTime.now()))
Assertions.assertTrue(
LocalDateTime.parse(data["now"], DateTimeFormatter.ISO_DATE_TIME).plusHours(1).isAfter(LocalDateTime.now())
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class DataFetcherWithDirectivesTest {

every { applicationContextMock.getBeansWithAnnotation(DgsComponent::class.java) } returns mapOf(Pair("helloFetcher", queryFetcher))
every { applicationContextMock.getBeansWithAnnotation(DgsScalar::class.java) } returns emptyMap()
every { applicationContextMock.getBeansWithAnnotation(DgsDirective::class.java) } returns emptyMap()

val provider = DgsSchemaProvider(applicationContextMock, Optional.empty(), Optional.empty(), Optional.empty())
val schema = provider.schema(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ internal class DgsDataFetchingEnvironmentTest {
fun getDataLoader() {
every { applicationContextMock.getBeansWithAnnotation(DgsComponent::class.java) } returns mapOf(Pair("helloFetcher", helloFetcher))
every { applicationContextMock.getBeansWithAnnotation(DgsDataLoader::class.java) } returns mapOf(Pair("helloLoader", ExampleBatchLoader()))

every { applicationContextMock.getBeansWithAnnotation(DgsDirective::class.java) } returns emptyMap()
val provider = DgsDataLoaderProvider(applicationContextMock)
provider.findDataLoaders()
val dataLoaderRegistry = provider.buildRegistry()
Expand All @@ -141,7 +141,7 @@ internal class DgsDataFetchingEnvironmentTest {
fun getDataLoaderWithBasicDFE() {
every { applicationContextMock.getBeansWithAnnotation(DgsComponent::class.java) } returns mapOf(Pair("helloFetcher", helloFetcherWithBasicDFE))
every { applicationContextMock.getBeansWithAnnotation(DgsDataLoader::class.java) } returns mapOf(Pair("helloLoader", ExampleBatchLoader()))

every { applicationContextMock.getBeansWithAnnotation(DgsDirective::class.java) } returns emptyMap()
val provider = DgsDataLoaderProvider(applicationContextMock)
provider.findDataLoaders()
val dataLoaderRegistry = provider.buildRegistry()
Expand All @@ -163,6 +163,7 @@ internal class DgsDataFetchingEnvironmentTest {
@Test
fun getDataLoaderFromField() {
every { applicationContextMock.getBeansWithAnnotation(DgsDataLoader::class.java) } returns emptyMap()
every { applicationContextMock.getBeansWithAnnotation(DgsDirective::class.java) } returns emptyMap()
every { applicationContextMock.getBeansWithAnnotation(DgsComponent::class.java) } returns mapOf(Pair("helloFetcher", helloFetcherWithField), Pair("helloLoader", ExampleBatchLoaderFromField()))

val provider = DgsDataLoaderProvider(applicationContextMock)
Expand All @@ -186,6 +187,7 @@ internal class DgsDataFetchingEnvironmentTest {
@Test
fun getMultipleDataLoadersFromField() {
every { applicationContextMock.getBeansWithAnnotation(DgsDataLoader::class.java) } returns emptyMap()
every { applicationContextMock.getBeansWithAnnotation(DgsDirective::class.java) } returns emptyMap()
every { applicationContextMock.getBeansWithAnnotation(DgsComponent::class.java) } returns mapOf(Pair("helloFetcher", helloFetcherWithMultipleField), Pair("helloLoader", ExampleMultipleBatchLoadersAsField()))

val provider = DgsDataLoaderProvider(applicationContextMock)
Expand All @@ -208,7 +210,7 @@ internal class DgsDataFetchingEnvironmentTest {
fun getMappedDataLoader() {
every { applicationContextMock.getBeansWithAnnotation(DgsComponent::class.java) } returns mapOf(Pair("helloFetcher", helloFetcherMapped))
every { applicationContextMock.getBeansWithAnnotation(DgsDataLoader::class.java) } returns mapOf(Pair("helloLoader", ExampleMappedBatchLoader()))

every { applicationContextMock.getBeansWithAnnotation(DgsDirective::class.java) } returns emptyMap()
val provider = DgsDataLoaderProvider(applicationContextMock)
provider.findDataLoaders()
val dataLoaderRegistry = provider.buildRegistry()
Expand All @@ -230,6 +232,7 @@ internal class DgsDataFetchingEnvironmentTest {
@Test
fun getMappedDataLoaderFromField() {
every { applicationContextMock.getBeansWithAnnotation(DgsDataLoader::class.java) } returns emptyMap()
every { applicationContextMock.getBeansWithAnnotation(DgsDirective::class.java) } returns emptyMap()
every { applicationContextMock.getBeansWithAnnotation(DgsComponent::class.java) } returns mapOf(Pair("helloFetcher", helloFetcherWithFieldMapped), Pair("helloLoader", ExampleMappedBatchLoaderFromField()))

val provider = DgsDataLoaderProvider(applicationContextMock)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ class DgsFederationResolverTest {
}

every { applicationContextMock.getBeansWithAnnotation(DgsScalar::class.java) } returns emptyMap()
every { applicationContextMock.getBeansWithAnnotation(DgsDirective::class.java) } returns emptyMap()
every { applicationContextMock.getBeansWithAnnotation(DgsComponent::class.java) } returns mapOf(Pair("MovieEntityFetcher", movieEntityFetcher))
dgsSchemaProvider.schema("""type Query {}""")

Expand Down Expand Up @@ -222,6 +223,7 @@ class DgsFederationResolverTest {
}
}
every { applicationContextMock.getBeansWithAnnotation(DgsScalar::class.java) } returns emptyMap()
every { applicationContextMock.getBeansWithAnnotation(DgsDirective::class.java) } returns emptyMap()
every { applicationContextMock.getBeansWithAnnotation(DgsComponent::class.java) } returns mapOf(Pair("MovieEntityFetcher", movieEntityFetcher))
dgsSchemaProvider.schema("""type Query {}""")

Expand All @@ -245,6 +247,7 @@ class DgsFederationResolverTest {
}
}
every { applicationContextMock.getBeansWithAnnotation(DgsScalar::class.java) } returns emptyMap()
every { applicationContextMock.getBeansWithAnnotation(DgsDirective::class.java) } returns emptyMap()
every { applicationContextMock.getBeansWithAnnotation(DgsComponent::class.java) } returns mapOf(Pair("MovieEntityFetcher", movieEntityFetcher))
dgsSchemaProvider.schema("""type Query {}""")

Expand All @@ -270,6 +273,7 @@ class DgsFederationResolverTest {

private fun testEntityFetcher(movieEntityFetcher: Any) {
every { applicationContextMock.getBeansWithAnnotation(DgsScalar::class.java) } returns emptyMap()
every { applicationContextMock.getBeansWithAnnotation(DgsDirective::class.java) } returns emptyMap()
every { applicationContextMock.getBeansWithAnnotation(DgsComponent::class.java) } returns mapOf(Pair("MovieEntityFetcher", movieEntityFetcher))
dgsSchemaProvider.schema("""type Query {}""")

Expand Down

0 comments on commit 917bb29

Please sign in to comment.