diff --git a/executions/graphql-kotlin-dataloader-instrumentation/src/main/kotlin/com/expediagroup/graphql/dataloader/instrumentation/extensions/CompletableFutureExtensions.kt b/executions/graphql-kotlin-dataloader-instrumentation/src/main/kotlin/com/expediagroup/graphql/dataloader/instrumentation/extensions/CompletableFutureExtensions.kt index 1bf7c83938..98df382ed5 100644 --- a/executions/graphql-kotlin-dataloader-instrumentation/src/main/kotlin/com/expediagroup/graphql/dataloader/instrumentation/extensions/CompletableFutureExtensions.kt +++ b/executions/graphql-kotlin-dataloader-instrumentation/src/main/kotlin/com/expediagroup/graphql/dataloader/instrumentation/extensions/CompletableFutureExtensions.kt @@ -27,16 +27,14 @@ import org.dataloader.DataLoader import java.util.concurrent.CompletableFuture /** - * Check if all futures collected on [KotlinDataLoaderRegistry.dispatchAll] were handled and we have more futures than we - * had when we started to dispatch, if so, means that [DataLoader]s were chained + * Check if all futures collected on [KotlinDataLoaderRegistry.dispatchAll] were handled + * and if we have more futures than we had when we started to dispatch, if so, + * means that [DataLoader]s were chained, so we need to dispatch the dataLoaderRegistry. */ fun CompletableFuture.dispatchIfNeeded( environment: DataFetchingEnvironment ): CompletableFuture { - val dataLoaderRegistry = - environment - .graphQlContext.get(KotlinDataLoaderRegistry::class) - ?: throw MissingKotlinDataLoaderRegistryException() + val dataLoaderRegistry = environment.dataLoaderRegistry as? KotlinDataLoaderRegistry ?: throw MissingKotlinDataLoaderRegistryException() if (dataLoaderRegistry.dataLoadersInvokedOnDispatch()) { val cantContinueExecution = when { diff --git a/executions/graphql-kotlin-dataloader-instrumentation/src/test/kotlin/com/expediagroup/graphql/dataloader/instrumentation/syncexhaustion/DataLoaderSyncExecutionExhaustedInstrumentationTest.kt b/executions/graphql-kotlin-dataloader-instrumentation/src/test/kotlin/com/expediagroup/graphql/dataloader/instrumentation/syncexhaustion/DataLoaderSyncExecutionExhaustedInstrumentationTest.kt index 837130db7c..07fe8cc357 100644 --- a/executions/graphql-kotlin-dataloader-instrumentation/src/test/kotlin/com/expediagroup/graphql/dataloader/instrumentation/syncexhaustion/DataLoaderSyncExecutionExhaustedInstrumentationTest.kt +++ b/executions/graphql-kotlin-dataloader-instrumentation/src/test/kotlin/com/expediagroup/graphql/dataloader/instrumentation/syncexhaustion/DataLoaderSyncExecutionExhaustedInstrumentationTest.kt @@ -1,5 +1,5 @@ /* - * Copyright 2022 Expedia, Inc + * Copyright 2023 Expedia, Inc * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,6 +26,7 @@ import io.mockk.verify import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.Test import kotlin.test.assertEquals +import kotlin.test.assertTrue class DataLoaderSyncExecutionExhaustedInstrumentationTest { private val dataLoaderSyncExecutionExhaustedInstrumentation = spyk(DataLoaderSyncExecutionExhaustedInstrumentation()) @@ -389,6 +390,10 @@ class DataLoaderSyncExecutionExhaustedInstrumentationTest { assertEquals(3, results.size) + results.forEach { result -> + assertTrue(result.errors.isEmpty()) + } + val missionStatistics = kotlinDataLoaderRegistry.dataLoadersMap["MissionDataLoader"]?.statistics val planetStatistics = kotlinDataLoaderRegistry.dataLoadersMap["PlanetsByMissionDataLoader"]?.statistics @@ -426,10 +431,15 @@ class DataLoaderSyncExecutionExhaustedInstrumentationTest { ) assertEquals(3, results.size) + results.forEach { result -> + assertTrue(result.errors.isEmpty()) + } + val astronautStatistics = kotlinDataLoaderRegistry.dataLoadersMap["AstronautDataLoader"]?.statistics val missionsByAstronautStatistics = kotlinDataLoaderRegistry.dataLoadersMap["MissionsByAstronautDataLoader"]?.statistics val planetStatistics = kotlinDataLoaderRegistry.dataLoadersMap["PlanetsByMissionDataLoader"]?.statistics + assertEquals(1, astronautStatistics?.batchInvokeCount) assertEquals(1, missionsByAstronautStatistics?.batchInvokeCount) assertEquals(1, planetStatistics?.batchInvokeCount) } @@ -468,11 +478,16 @@ class DataLoaderSyncExecutionExhaustedInstrumentationTest { DataLoaderInstrumentationStrategy.SYNC_EXHAUSTION ) + val astronautStatistics = kotlinDataLoaderRegistry.dataLoadersMap["AstronautDataLoader"]?.statistics val missionsByAstronautStatistics = kotlinDataLoaderRegistry.dataLoadersMap["MissionsByAstronautDataLoader"]?.statistics val planetStatistics = kotlinDataLoaderRegistry.dataLoadersMap["PlanetsByMissionDataLoader"]?.statistics assertEquals(2, results.size) + results.forEach { result -> + assertTrue(result.errors.isEmpty()) + } + assertEquals(1, astronautStatistics?.batchInvokeCount) assertEquals(1, missionsByAstronautStatistics?.batchInvokeCount) assertEquals(1, planetStatistics?.batchInvokeCount) }