diff --git a/graphql-dgs-example-java/src/main/java/com/netflix/graphql/dgs/example/datafetcher/MyInstrumentation.java b/graphql-dgs-example-java/src/main/java/com/netflix/graphql/dgs/example/datafetcher/MyInstrumentation.java new file mode 100644 index 000000000..072984520 --- /dev/null +++ b/graphql-dgs-example-java/src/main/java/com/netflix/graphql/dgs/example/datafetcher/MyInstrumentation.java @@ -0,0 +1,45 @@ +/* + * Copyright 2022 Netflix, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.netflix.graphql.dgs.example.datafetcher; + +import com.netflix.graphql.dgs.mvc.DgsRestController; +import graphql.ExecutionResult; +import graphql.ExecutionResultImpl; +import graphql.execution.instrumentation.SimpleInstrumentation; +import graphql.execution.instrumentation.parameters.InstrumentationExecutionParameters; +import org.springframework.stereotype.Component; + +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.CompletableFuture; + +@Component +public class MyInstrumentation extends SimpleInstrumentation { + @Override + public CompletableFuture instrumentExecutionResult(ExecutionResult executionResult, InstrumentationExecutionParameters parameters) { + HashMap extensions = new HashMap<>(); + if(executionResult.getExtensions() != null) { + extensions.putAll(executionResult.getExtensions()); + } + + Map responseHeaders = new HashMap<>(); + responseHeaders.put("myHeader", "hello"); + extensions.put(DgsRestController.DGS_RESPONSE_HEADERS_KEY, responseHeaders); + + return super.instrumentExecutionResult(new ExecutionResultImpl(executionResult.getData(), executionResult.getErrors(), extensions), parameters); + } +} diff --git a/graphql-dgs-spring-webmvc/src/main/kotlin/com/netflix/graphql/dgs/mvc/DgsRestController.kt b/graphql-dgs-spring-webmvc/src/main/kotlin/com/netflix/graphql/dgs/mvc/DgsRestController.kt index 7f108a7e7..929024b27 100644 --- a/graphql-dgs-spring-webmvc/src/main/kotlin/com/netflix/graphql/dgs/mvc/DgsRestController.kt +++ b/graphql-dgs-spring-webmvc/src/main/kotlin/com/netflix/graphql/dgs/mvc/DgsRestController.kt @@ -75,6 +75,7 @@ open class DgsRestController( ) { companion object { + const val DGS_RESPONSE_HEADERS_KEY = "dgs-response-headers" private val logger: Logger = LoggerFactory.getLogger(DgsRestController::class.java) } @@ -225,6 +226,27 @@ open class DgsRestController( .body("Trying to execute subscription on /graphql. Use /subscriptions instead!") } + val responseHeaders = if (executionResult.extensions?.containsKey(DGS_RESPONSE_HEADERS_KEY) == true) { + val dgsResponseHeaders = executionResult.extensions[DGS_RESPONSE_HEADERS_KEY] + val responseHeaders = HttpHeaders() + if (dgsResponseHeaders is Map<*, *>) { + dgsResponseHeaders.forEach { + if (it.key != null) { + responseHeaders.add(it.key.toString(), it.value?.toString()) + } + } + } else { + logger.warn( + "{} must be of type java.util.Map, but was {}", + DGS_RESPONSE_HEADERS_KEY, + dgsResponseHeaders?.javaClass?.name + ) + } + + executionResult.extensions.remove(DGS_RESPONSE_HEADERS_KEY) + responseHeaders + } else HttpHeaders() + val result = try { TimeTracer.logTime( { mapper.writeValueAsBytes(executionResult.toSpecification()) }, @@ -238,6 +260,6 @@ open class DgsRestController( mapper.writeValueAsBytes(errorResponse.toSpecification()) } - return ResponseEntity.ok(result) + return ResponseEntity(result, responseHeaders, HttpStatus.OK) } } diff --git a/graphql-dgs-spring-webmvc/src/test/kotlin/com/netflix/graphql/dgs/mvc/DgsRestControllerTest.kt b/graphql-dgs-spring-webmvc/src/test/kotlin/com/netflix/graphql/dgs/mvc/DgsRestControllerTest.kt index f591e16e2..69f884b72 100644 --- a/graphql-dgs-spring-webmvc/src/test/kotlin/com/netflix/graphql/dgs/mvc/DgsRestControllerTest.kt +++ b/graphql-dgs-spring-webmvc/src/test/kotlin/com/netflix/graphql/dgs/mvc/DgsRestControllerTest.kt @@ -181,6 +181,30 @@ class DgsRestControllerTest { .isEqualTo(HttpStatus.BAD_REQUEST) } + @Test + fun `Writes response headers when dgs-response-headers are set in extensions object`() { + val queryString = "query { hello }" + val requestBody = """ + { + "query": "$queryString" + } + """.trimIndent() + + every { + dgsQueryExecutor.execute( + queryString, + emptyMap(), + any(), + any(), + any(), + any() + ) + } returns ExecutionResultImpl.newExecutionResult().data(mapOf(Pair("hello", "hello"))).extensions(mutableMapOf(Pair(DgsRestController.DGS_RESPONSE_HEADERS_KEY, mapOf(Pair("myHeader", "hello")))) as Map?).build() + + val result = DgsRestController(dgsQueryExecutor).graphql(requestBody.toByteArray(), null, null, null, httpHeaders, webRequest) + assertThat(result.headers["myHeader"]).contains("hello") + } + data class GraphQLResponse(val data: Map = emptyMap(), val errors: List = emptyList()) @JsonIgnoreProperties(ignoreUnknown = true)