Skip to content

Commit

Permalink
Fix completablefuture wrapping
Browse files Browse the repository at this point in the history
  • Loading branch information
paulbakker committed Oct 12, 2023
1 parent 9d446eb commit b4516d6
Show file tree
Hide file tree
Showing 3 changed files with 162 additions and 39 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* Copyright 2023 Netflix, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.netflix.graphql.dgs.internal

import org.reactivestreams.Publisher
import org.springframework.core.task.AsyncTaskExecutor
import java.lang.reflect.Method
import java.util.concurrent.CompletableFuture
import java.util.concurrent.CompletionStage
import kotlin.reflect.KFunction
import kotlin.reflect.KType
import kotlin.reflect.full.isSubtypeOf
import kotlin.reflect.typeOf

internal class CompletableFutureWrapper(private val taskExecutor: AsyncTaskExecutor?) {
private val supportsReactor: Boolean = try {
Class.forName("org.reactivestreams.Publisher")
true
} catch (ex: Exception) {
false
}

/**
* Wrap the call to a data fetcher in CompletableFuture to enable parallel behavior.
* Used when virtual threads are enabled.
*/
fun wrapInCompletableFuture(function: () -> Any?): Any? {
return CompletableFuture.supplyAsync({
return@supplyAsync function.invoke()
}, taskExecutor)
}

/**
* Decides if a data fetcher method should be wrapped in CompletableFuture automatically.
* This is only done when a taskExecutor is available, and if the data fetcher doesn't explicitly return CompletableFuture already.
* Used when virtual threads are enabled.
*/
fun shouldWrapInCompletableFuture(kFunc: KFunction<*>): Boolean {
return taskExecutor != null &&
!kFunc.returnType.isSubtypeOf(typeOf<CompletionStage<*>>()) &&
!isReactive(kFunc.returnType)
}

private fun isReactive(returnType: KType): Boolean {
return supportsReactor && returnType.isSubtypeOf(typeOf<Publisher<*>>())
}

/**
* Decides if a data fetcher method should be wrapped in CompletableFuture automatically.
* This is only done when a taskExecutor is available, and if the data fetcher doesn't explicitly return CompletableFuture already.
* Used when virtual threads are enabled.
*/
fun shouldWrapInCompletableFuture(method: Method): Boolean {
return taskExecutor != null &&
!CompletionStage::class.java.isAssignableFrom(method.returnType) &&
!isReactive(method.returnType)
}

private fun isReactive(returnType: Class<*>): Boolean {
return supportsReactor && Publisher::class.java.isAssignableFrom(returnType)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,25 +30,22 @@ import org.springframework.util.CollectionUtils
import org.springframework.util.ReflectionUtils
import java.lang.reflect.InvocationTargetException
import java.lang.reflect.Method
import java.util.concurrent.CompletableFuture
import java.util.concurrent.CompletionStage
import kotlin.reflect.KFunction
import kotlin.reflect.KParameter
import kotlin.reflect.full.callSuspendBy
import kotlin.reflect.full.isSubtypeOf
import kotlin.reflect.jvm.kotlinFunction
import kotlin.reflect.typeOf

class DataFetcherInvoker internal constructor(
private val dgsComponent: Any,
method: Method,
private val resolvers: ArgumentResolverComposite,
parameterNameDiscoverer: ParameterNameDiscoverer,
private val taskExecutor: AsyncTaskExecutor?
taskExecutor: AsyncTaskExecutor?
) : DataFetcher<Any?> {

private val bridgedMethod: Method = BridgeMethodResolver.findBridgedMethod(method)
private val kotlinFunction: KFunction<*>? = bridgedMethod.kotlinFunction
private val completableFutureWrapper = CompletableFutureWrapper(taskExecutor)

private val methodParameters: List<MethodParameter> = bridgedMethod.parameters.map { parameter ->
val methodParameter = SynthesizingMethodParameter.forParameter(parameter)
Expand All @@ -62,8 +59,8 @@ class DataFetcherInvoker internal constructor(

override fun get(environment: DataFetchingEnvironment): Any? {
if (methodParameters.isEmpty()) {
if (shouldWrapInCompletableFuture(bridgedMethod)) {
return wrapInCompletableFuture { ReflectionUtils.invokeMethod(bridgedMethod, dgsComponent) }
if (completableFutureWrapper.shouldWrapInCompletableFuture(bridgedMethod)) {
return completableFutureWrapper.wrapInCompletableFuture { ReflectionUtils.invokeMethod(bridgedMethod, dgsComponent) }
}
return ReflectionUtils.invokeMethod(bridgedMethod, dgsComponent)
}
Expand All @@ -81,8 +78,8 @@ class DataFetcherInvoker internal constructor(
args[idx] = resolvers.resolveArgument(parameter, environment)
}

return if (shouldWrapInCompletableFuture(bridgedMethod)) {
wrapInCompletableFuture { ReflectionUtils.invokeMethod(bridgedMethod, dgsComponent, *args) }
return if (completableFutureWrapper.shouldWrapInCompletableFuture(bridgedMethod)) {
completableFutureWrapper.wrapInCompletableFuture { ReflectionUtils.invokeMethod(bridgedMethod, dgsComponent, *args) }
} else {
ReflectionUtils.invokeMethod(bridgedMethod, dgsComponent, *args)
}
Expand Down Expand Up @@ -116,8 +113,8 @@ class DataFetcherInvoker internal constructor(
}.onErrorMap(InvocationTargetException::class.java) { it.targetException }
}
return try {
if (shouldWrapInCompletableFuture(kFunc)) {
wrapInCompletableFuture { kFunc.callBy(argsByName) }
if (completableFutureWrapper.shouldWrapInCompletableFuture(kFunc)) {
completableFutureWrapper.wrapInCompletableFuture { kFunc.callBy(argsByName) }
} else {
kFunc.callBy(argsByName)
}
Expand All @@ -130,32 +127,4 @@ class DataFetcherInvoker internal constructor(
return "Could not resolve parameter [${param.parameterIndex}] in " +
param.executable.toGenericString() + if (message.isNotEmpty()) ": $message" else ""
}

/**
* Wrap the call to a data fetcher in CompletableFuture to enable parallel behavior.
* Used when virtual threads are enabled.
*/
private fun wrapInCompletableFuture(function: () -> Any?): Any? {
return CompletableFuture.supplyAsync({
return@supplyAsync function.invoke()
}, taskExecutor)
}

/**
* Decides if a data fetcher method should be wrapped in CompletableFuture automatically.
* This is only done when a taskExecutor is available, and if the data fetcher doesn't explicitly return CompletableFuture already.
* Used when virtual threads are enabled.
*/
private fun shouldWrapInCompletableFuture(kFunc: KFunction<*>): Boolean {
return taskExecutor != null && !kFunc.returnType.isSubtypeOf(typeOf<CompletionStage<Any>>())
}

/**
* Decides if a data fetcher method should be wrapped in CompletableFuture automatically.
* This is only done when a taskExecutor is available, and if the data fetcher doesn't explicitly return CompletableFuture already.
* Used when virtual threads are enabled.
*/
private fun shouldWrapInCompletableFuture(method: Method): Boolean {
return taskExecutor != null && !CompletionStage::class.java.isAssignableFrom(method.returnType)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
package com.netflix.graphql.dgs.internal

import io.mockk.impl.annotations.MockK
import io.mockk.junit5.MockKExtension
import org.assertj.core.api.Assertions.assertThat
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.extension.ExtendWith
import org.springframework.core.task.AsyncTaskExecutor
import reactor.core.publisher.Mono
import java.util.concurrent.CompletableFuture
import java.util.function.Function
import kotlin.reflect.jvm.ExperimentalReflectionOnLambdas
import kotlin.reflect.jvm.reflect

@OptIn(ExperimentalReflectionOnLambdas::class)
@ExtendWith(MockKExtension::class)
class CompletableFutureWrapperTest {
@MockK(relaxUnitFun = true)
lateinit var mockTaskExecutor: AsyncTaskExecutor

@Test
fun `If no taskExecutor is set, no wrapping should happen`() {
val completableFutureWrapper = CompletableFutureWrapper(null)

// Check Kotlin method
assertThat(completableFutureWrapper.shouldWrapInCompletableFuture(fun(): String { return "hello" }.reflect()!!)).isFalse()

// Check Java method
val stringMethod = String::class.java.getMethod("toString")
assertThat(completableFutureWrapper.shouldWrapInCompletableFuture(stringMethod)).isFalse()
}

@Test
fun `A Kotlin String function should be wrapped`() {
val completableFutureWrapper = CompletableFutureWrapper(VirtualThreadTaskExecutor())
assertThat(completableFutureWrapper.shouldWrapInCompletableFuture(fun(): String { return "hello" }.reflect()!!)).isTrue()
}

@Test
fun `A Kotlin CompletableFuture function should not be wrapped`() {
val completableFutureWrapper = CompletableFutureWrapper(VirtualThreadTaskExecutor())
assertThat(completableFutureWrapper.shouldWrapInCompletableFuture(fun(): CompletableFuture<String> { return CompletableFuture() }.reflect()!!)).isFalse()
}

@Test
fun `A Kotlin Mono function should not be wrapped`() {
val completableFutureWrapper = CompletableFutureWrapper(VirtualThreadTaskExecutor())
assertThat(completableFutureWrapper.shouldWrapInCompletableFuture(fun(): Mono<String> { return Mono.just("hi") }.reflect()!!)).isFalse()
}

@Test
fun `A Java String method should be wrapped`() {
val completableFutureWrapper = CompletableFutureWrapper(VirtualThreadTaskExecutor())
val stringMethod = String::class.java.getMethod("toString")
assertThat(completableFutureWrapper.shouldWrapInCompletableFuture(stringMethod)).isTrue()
}

@Test
fun `A Java CompletableFuture method should not be wrapped`() {
val completableFutureWrapper = CompletableFutureWrapper(VirtualThreadTaskExecutor())
val cfMethod = CompletableFuture::class.java.getMethod("thenApplyAsync", Function::class.java)
assertThat(completableFutureWrapper.shouldWrapInCompletableFuture(cfMethod)).isFalse()
}

@Test
fun `A Java Mono method should not be wrapped`() {
val completableFutureWrapper = CompletableFutureWrapper(VirtualThreadTaskExecutor())
val monoMethod = Mono::class.java.getMethod("empty")
assertThat(completableFutureWrapper.shouldWrapInCompletableFuture(monoMethod)).isFalse()
}

@Test
fun `A method should successfully get wrapped`() {
val completableFutureWrapper = CompletableFutureWrapper(mockTaskExecutor)
val wrapped = completableFutureWrapper.wrapInCompletableFuture { fun(): String { return "hello" } }
assertThat(wrapped).isInstanceOf(CompletableFuture::class.java)
}
}

0 comments on commit b4516d6

Please sign in to comment.