Skip to content

Commit

Permalink
Merge branch 'master' into Netflix#1058
Browse files Browse the repository at this point in the history
  • Loading branch information
antholeole committed Oct 6, 2022
2 parents e679625 + 52508df commit 69a6dd5
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 1 deletion.
@@ -0,0 +1,45 @@
/*
* Copyright 2022 Netflix, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.netflix.graphql.dgs.example.datafetcher;

import com.netflix.graphql.dgs.mvc.DgsRestController;
import graphql.ExecutionResult;
import graphql.ExecutionResultImpl;
import graphql.execution.instrumentation.SimpleInstrumentation;
import graphql.execution.instrumentation.parameters.InstrumentationExecutionParameters;
import org.springframework.stereotype.Component;

import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.CompletableFuture;

@Component
public class MyInstrumentation extends SimpleInstrumentation {
@Override
public CompletableFuture<ExecutionResult> instrumentExecutionResult(ExecutionResult executionResult, InstrumentationExecutionParameters parameters) {
HashMap<Object, Object> extensions = new HashMap<>();
if(executionResult.getExtensions() != null) {
extensions.putAll(executionResult.getExtensions());
}

Map<String, String> responseHeaders = new HashMap<>();
responseHeaders.put("myHeader", "hello");
extensions.put(DgsRestController.DGS_RESPONSE_HEADERS_KEY, responseHeaders);

return super.instrumentExecutionResult(new ExecutionResultImpl(executionResult.getData(), executionResult.getErrors(), extensions), parameters);
}
}
Expand Up @@ -75,6 +75,7 @@ open class DgsRestController(
) {

companion object {
const val DGS_RESPONSE_HEADERS_KEY = "dgs-response-headers"
private val logger: Logger = LoggerFactory.getLogger(DgsRestController::class.java)
}

Expand Down Expand Up @@ -225,6 +226,27 @@ open class DgsRestController(
.body("Trying to execute subscription on /graphql. Use /subscriptions instead!")
}

val responseHeaders = if (executionResult.extensions?.containsKey(DGS_RESPONSE_HEADERS_KEY) == true) {
val dgsResponseHeaders = executionResult.extensions[DGS_RESPONSE_HEADERS_KEY]
val responseHeaders = HttpHeaders()
if (dgsResponseHeaders is Map<*, *>) {
dgsResponseHeaders.forEach {
if (it.key != null) {
responseHeaders.add(it.key.toString(), it.value?.toString())
}
}
} else {
logger.warn(
"{} must be of type java.util.Map, but was {}",
DGS_RESPONSE_HEADERS_KEY,
dgsResponseHeaders?.javaClass?.name
)
}

executionResult.extensions.remove(DGS_RESPONSE_HEADERS_KEY)
responseHeaders
} else HttpHeaders()

val result = try {
TimeTracer.logTime(
{ mapper.writeValueAsBytes(executionResult.toSpecification()) },
Expand All @@ -238,6 +260,6 @@ open class DgsRestController(
mapper.writeValueAsBytes(errorResponse.toSpecification())
}

return ResponseEntity.ok(result)
return ResponseEntity(result, responseHeaders, HttpStatus.OK)
}
}
Expand Up @@ -181,6 +181,30 @@ class DgsRestControllerTest {
.isEqualTo(HttpStatus.BAD_REQUEST)
}

@Test
fun `Writes response headers when dgs-response-headers are set in extensions object`() {
val queryString = "query { hello }"
val requestBody = """
{
"query": "$queryString"
}
""".trimIndent()

every {
dgsQueryExecutor.execute(
queryString,
emptyMap(),
any(),
any(),
any(),
any()
)
} returns ExecutionResultImpl.newExecutionResult().data(mapOf(Pair("hello", "hello"))).extensions(mutableMapOf(Pair(DgsRestController.DGS_RESPONSE_HEADERS_KEY, mapOf(Pair("myHeader", "hello")))) as Map<Any, Any>?).build()

val result = DgsRestController(dgsQueryExecutor).graphql(requestBody.toByteArray(), null, null, null, httpHeaders, webRequest)
assertThat(result.headers["myHeader"]).contains("hello")
}

data class GraphQLResponse(val data: Map<String, Any> = emptyMap(), val errors: List<GraphQLError> = emptyList())

@JsonIgnoreProperties(ignoreUnknown = true)
Expand Down

0 comments on commit 69a6dd5

Please sign in to comment.