Skip to content

Commit

Permalink
GraphQLContext map cannot contain null keys nor values
Browse files Browse the repository at this point in the history
  • Loading branch information
Dariusz Kuc committed Jan 27, 2022
1 parent ea06a54 commit f99d268
Show file tree
Hide file tree
Showing 17 changed files with 43 additions and 39 deletions.
Expand Up @@ -25,13 +25,16 @@ import io.ktor.request.ApplicationRequest
*/
class KtorGraphQLContextFactory : GraphQLContextFactory<ApplicationRequest> {

override suspend fun generateContextMap(request: ApplicationRequest): Map<*, Any?> = mapOf(
override suspend fun generateContextMap(request: ApplicationRequest): Map<Any, Any> = mutableMapOf<Any, Any>(
"user" to User(
email = "fake@site.com",
firstName = "Someone",
lastName = "You Don't know",
universityId = 4
),
"customHeader" to request.headers["my-custom-header"]
)
)
).also { map ->
request.headers["my-custom-header"]?.let { customHeader ->
map["customHeader"] = customHeader
}
}
}
Expand Up @@ -26,7 +26,7 @@ import org.springframework.web.reactive.function.server.ServerRequest
*/
@Component
class MyGraphQLContextFactory : SpringGraphQLContextFactory() {
override suspend fun generateContextMap(request: ServerRequest): Map<*, Any?> = mapOf(
override suspend fun generateContextMap(request: ServerRequest): Map<*, Any> = mapOf(
"myCustomValue" to (request.headers().firstHeader("MyHeader") ?: "defaultContext")
)
}
Expand Up @@ -27,7 +27,7 @@ class MySubscriptionHooks : ApolloSubscriptionHooks {
override fun onConnect(
connectionParams: Map<String, String>,
session: WebSocketSession,
graphQLContext: Map<*, Any?>?
graphQLContext: Map<*, Any>
): Map<*, Any?> = mapOf(
"auth" to connectionParams["Authorization"]
)
Expand Down
@@ -1,5 +1,5 @@
/*
* Copyright 2021 Expedia, Inc
* Copyright 2022 Expedia, Inc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -20,12 +20,13 @@ import graphql.GraphQLError
import graphql.execution.DataFetcherResult
import graphql.schema.DataFetcher
import graphql.schema.DataFetchingEnvironment
import kotlinx.coroutines.GlobalScope
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.async
import kotlinx.coroutines.awaitAll
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.future.future
import java.util.concurrent.CompletableFuture
import kotlin.coroutines.EmptyCoroutineContext

private const val TYPENAME_FIELD = "__typename"
private const val REPRESENTATIONS = "representations"
Expand Down Expand Up @@ -54,7 +55,8 @@ open class EntityResolver(resolvers: List<FederatedTypeResolver<*>>) : DataFetch
val representations: List<Map<String, Any>> = env.getArgument(REPRESENTATIONS)
val indexedBatchRequestsByType = representations.withIndex().groupBy { it.value[TYPENAME_FIELD].toString() }

return GlobalScope.future {
val scope = env.graphQlContext.getOrDefault(CoroutineScope::class, CoroutineScope(EmptyCoroutineContext))
return scope.future {
val data = mutableListOf<Any?>()
val errors = mutableListOf<GraphQLError>()

Expand Down
Expand Up @@ -22,8 +22,8 @@ package com.expediagroup.graphql.server.execution
interface GraphQLContextFactory<in Request> {

/**
* GraphQL Java 17 has a new context map instead of a generic object. Implementing this method
* will set the context map in the execution input.
* Generate GraphQL context based on the incoming request and the corresponding response.
* If no context should be generated and used in the request, return empty map.
*/
suspend fun generateContextMap(request: Request): Map<*, Any?>? = null
suspend fun generateContextMap(request: Request): Map<*, Any> = emptyMap<Any, Any>()
}
Expand Up @@ -35,7 +35,7 @@ open class GraphQLRequestHandler(
* This should only be used for queries and mutations.
* Subscriptions require more specific server logic and will need to be handled separately.
*/
open suspend fun executeRequest(request: GraphQLRequest, graphQLContext: Map<*, Any?>? = null): GraphQLResponse<*> {
open suspend fun executeRequest(request: GraphQLRequest, graphQLContext: Map<*, Any> = emptyMap<Any, Any>()): GraphQLResponse<*> {
// We should generate a new registry for every request
val dataLoaderRegistry = dataLoaderRegistryFactory?.generate()
val executionInput = request.toExecutionInput(dataLoaderRegistry, graphQLContext)
Expand Down
Expand Up @@ -48,7 +48,7 @@ open class GraphQLServer<Request>(
val graphQLRequest = requestParser.parseRequest(request)

if (graphQLRequest != null) {
val contextMap = contextFactory.generateContextMap(request) ?: emptyMap<Any, Any?>()
val contextMap = contextFactory.generateContextMap(request)

val customContext: CoroutineContext = contextMap[CoroutineContext::class] as? CoroutineContext ?: EmptyCoroutineContext
val graphQLExecutionScope = CoroutineScope(coroutineContext + customContext + SupervisorJob())
Expand Down
Expand Up @@ -58,7 +58,7 @@ class GraphQLServerTest {
coEvery { parseRequest(any()) } returns mockk<GraphQLRequest>()
}
val mockContextFactory = mockk<GraphQLContextFactory<MockHttpRequest>> {
coEvery { generateContextMap(any()) } returns null
coEvery { generateContextMap(any()) } returns emptyMap<Any, Any>()
}
val mockHandler = mockk<GraphQLRequestHandler> {
coEvery { executeRequest(any(), any()) } returns mockk()
Expand All @@ -81,7 +81,7 @@ class GraphQLServerTest {
coEvery { parseRequest(any()) } returns null
}
val mockContextFactory = mockk<GraphQLContextFactory<MockHttpRequest>> {
coEvery { generateContextMap(any()) } returns null
coEvery { generateContextMap(any()) } returns emptyMap<Any, Any>()
}
val mockHandler = mockk<GraphQLRequestHandler> {
coEvery { executeRequest(any(), any()) } returns mockk()
Expand Down
Expand Up @@ -29,9 +29,10 @@ abstract class SpringGraphQLContextFactory : GraphQLContextFactory<ServerRequest
* Basic implementation of [SpringGraphQLContextFactory] that populates Apollo tracing header.
*/
class DefaultSpringGraphQLContextFactory : SpringGraphQLContextFactory() {
override suspend fun generateContextMap(request: ServerRequest): Map<*, Any?> = request.headers().firstHeader(FEDERATED_TRACING_HEADER_NAME)?.let { headerValue ->
mapOf(
FEDERATED_TRACING_HEADER_NAME to headerValue
)
} ?: emptyMap<Any, Any?>()
override suspend fun generateContextMap(request: ServerRequest): Map<*, Any> = mutableMapOf<Any, Any>()
.also { map ->
request.headers().firstHeader(FEDERATED_TRACING_HEADER_NAME)?.let { headerValue ->
map[FEDERATED_TRACING_HEADER_NAME] = headerValue
}
}
}
Expand Up @@ -31,8 +31,8 @@ interface ApolloSubscriptionHooks {
fun onConnect(
connectionParams: Map<String, String>,
session: WebSocketSession,
graphQLContext: Map<*, Any?>?
): Map<*, Any?>? = graphQLContext
graphQLContext: Map<*, Any>
): Map<*, Any> = graphQLContext

/**
* Called when the client executes a GraphQL operation.
Expand All @@ -41,7 +41,7 @@ interface ApolloSubscriptionHooks {
fun onOperation(
operationMessage: SubscriptionOperationMessage,
session: WebSocketSession,
graphQLContext: Map<*, Any?>?
graphQLContext: Map<*, Any>
): Unit = Unit

/**
Expand Down
Expand Up @@ -31,23 +31,21 @@ internal class ApolloSubscriptionSessionState {
internal val activeOperations = ConcurrentHashMap<String, ConcurrentHashMap<String, Subscription>>()

// The graphQL context is saved by web socket session id
private val cachedGraphQLContext = ConcurrentHashMap<String, Map<*, Any?>>()
private val cachedGraphQLContext = ConcurrentHashMap<String, Map<*, Any>>()

/**
* Save the context created from the factory and possibly updated in the onConnect hook.
* This allows us to include some intial state to be used when handling all the messages.
* This will be removed in [terminateSession].
*/
fun saveGraphQLContext(session: WebSocketSession, graphQLContext: Map<*, Any?>?) {
if (graphQLContext != null) {
cachedGraphQLContext[session.id] = graphQLContext
}
fun saveGraphQLContext(session: WebSocketSession, graphQLContext: Map<*, Any>) {
cachedGraphQLContext[session.id] = graphQLContext
}

/**
* Return the graphQL context for this session.
*/
fun getGraphQLContext(session: WebSocketSession): Map<*, Any?>? = cachedGraphQLContext[session.id]
fun getGraphQLContext(session: WebSocketSession): Map<*, Any> = cachedGraphQLContext[session.id] ?: emptyMap<Any, Any>()

/**
* Save the session that is sending keep alive messages.
Expand Down
Expand Up @@ -37,7 +37,7 @@ open class SpringGraphQLSubscriptionHandler(
private val dataLoaderRegistryFactory: DataLoaderRegistryFactory? = null
) {

fun executeSubscription(graphQLRequest: GraphQLRequest, graphQLContextMap: Map<*, Any?>? = null): Flow<GraphQLResponse<*>> {
fun executeSubscription(graphQLRequest: GraphQLRequest, graphQLContextMap: Map<*, Any> = emptyMap<Any, Any>()): Flow<GraphQLResponse<*>> {
val dataLoaderRegistry = dataLoaderRegistryFactory?.generate()
val input = graphQLRequest.toExecutionInput(dataLoaderRegistry, graphQLContextMap)

Expand Down
Expand Up @@ -29,5 +29,5 @@ abstract class SpringSubscriptionGraphQLContextFactory : GraphQLContextFactory<W
*/
class DefaultSpringSubscriptionGraphQLContextFactory : SpringSubscriptionGraphQLContextFactory() {

override suspend fun generateContextMap(request: WebSocketSession): Map<*, Any?>? = null
override suspend fun generateContextMap(request: WebSocketSession): Map<*, Any> = emptyMap<Any, Any>()
}
Expand Up @@ -64,7 +64,7 @@ class GraphQLContextFactoryIT(@Autowired private val testClient: WebTestClient)
@Bean
@ExperimentalCoroutinesApi
fun customContextFactory(): SpringGraphQLContextFactory = object : SpringGraphQLContextFactory() {
override suspend fun generateContextMap(request: ServerRequest): Map<*, Any?> = mapOf(
override suspend fun generateContextMap(request: ServerRequest): Map<*, Any> = mapOf(
"first" to (request.headers().firstHeader("X-First-Header") ?: "DEFAULT_FIRST"),
"second" to (request.headers().firstHeader("X-Second-Header") ?: "DEFAULT_SECOND")
)
Expand Down
Expand Up @@ -53,7 +53,7 @@ class RouteConfigurationIT(@Autowired private val testClient: WebTestClient) {

@Bean
fun customContextFactory(): SpringGraphQLContextFactory = object : SpringGraphQLContextFactory() {
override suspend fun generateContextMap(request: ServerRequest): Map<*, Any?> = mapOf(
override suspend fun generateContextMap(request: ServerRequest): Map<*, Any> = mapOf(
"value" to (request.headers().firstHeader("X-Custom-Header") ?: "default")
)
}
Expand Down
Expand Up @@ -53,7 +53,7 @@ class ApolloSubscriptionProtocolHandlerTest {
private val subscriptionHooks = SimpleSubscriptionHooks()
private fun SubscriptionOperationMessage.toJson() = objectMapper.writeValueAsString(this)
private val nullContextFactory: SpringSubscriptionGraphQLContextFactory = mockk {
coEvery { generateContextMap(any()) } returns null
coEvery { generateContextMap(any()) } returns emptyMap<Any, Any>()
}
private val simpleInitMessage = SubscriptionOperationMessage(GQL_CONNECTION_INIT.type)

Expand Down Expand Up @@ -460,7 +460,7 @@ class ApolloSubscriptionProtocolHandlerTest {
}
val subscriptionHandler: SpringGraphQLSubscriptionHandler = mockk()
val subscriptionHooks: ApolloSubscriptionHooks = mockk {
every { onConnect(any(), any(), any()) } returns null
every { onConnect(any(), any(), any()) } returns emptyMap<Any, Any>()
}
val handler = ApolloSubscriptionProtocolHandler(config, nullContextFactory, subscriptionHandler, objectMapper, subscriptionHooks)
val flux = handler.handle(simpleInitMessage.toJson(), session)
Expand All @@ -482,7 +482,7 @@ class ApolloSubscriptionProtocolHandlerTest {
}
val subscriptionHandler: SpringGraphQLSubscriptionHandler = mockk()
val subscriptionHooks: ApolloSubscriptionHooks = mockk {
every { onConnect(any(), any(), any()) } returns null
every { onConnect(any(), any(), any()) } returns emptyMap<Any, Any>()
}
val handler = ApolloSubscriptionProtocolHandler(config, nullContextFactory, subscriptionHandler, objectMapper, subscriptionHooks)
val flux = handler.handle(operationMessage, session)
Expand All @@ -507,7 +507,7 @@ class ApolloSubscriptionProtocolHandlerTest {
every { executeSubscription(eq(graphQLRequest), any()) } returns flowOf(expectedResponse)
}
val subscriptionHooks: ApolloSubscriptionHooks = mockk {
every { onConnect(any(), any(), any()) } returns null
every { onConnect(any(), any(), any()) } returns emptyMap<Any, Any>()
every { onOperation(any(), any(), any()) } returns Unit
every { onOperationComplete(any()) } returns Unit
}
Expand Down
Expand Up @@ -185,7 +185,7 @@ class SubscriptionWebSocketHandlerIT(
}

class CustomContextFactory : SpringSubscriptionGraphQLContextFactory() {
override suspend fun generateContextMap(request: WebSocketSession): Map<*, Any?> = mapOf(
override suspend fun generateContextMap(request: WebSocketSession): Map<*, Any> = mapOf(
"value" to (request.handshakeInfo.headers.getFirst("X-Custom-Header") ?: "default")
)
}
Expand Down

0 comments on commit f99d268

Please sign in to comment.