From f4886d2cdc93ac8daac73c931a25cc3a813095d0 Mon Sep 17 00:00:00 2001 From: Michael Staib Date: Tue, 10 Jan 2023 09:41:13 +0100 Subject: [PATCH] Fixed ScopedServices handling on resolvers (#5671) --- .../Extensions/SnapshotExtensions.cs | 24 ++++++++-- .../src/Abstractions/WellKnownContextData.cs | 2 - ...estExecutorBuilderExtensions.DataLoader.cs | 1 - .../Processing/MiddlewareContext.Global.cs | 2 + .../src/Fetching/DefaultDataLoaderRegistry.cs | 8 ++-- .../DataLoaderResolverContextExtensions.cs | 10 ++-- .../Expressions/Parameters/ServiceHelper.cs | 8 ++-- .../src/Types/Resolvers/IResolverContext.cs | 10 +++- .../src/Types/Types/DirectiveCollection.cs | 6 +-- .../QueryableCursorPagingProviderTests.cs | 2 + .../UseServiceScopeAttributeTests.cs | 48 ++++++++++++++++--- ...ojectionObjectFieldDescriptorExtensions.cs | 2 + .../DbContextParameterExpressionBuilder.cs | 2 +- ...rameworkObjectFieldDescriptorExtensions.cs | 6 +-- 14 files changed, 97 insertions(+), 34 deletions(-) diff --git a/src/CookieCrumble/src/CookieCrumble/Extensions/SnapshotExtensions.cs b/src/CookieCrumble/src/CookieCrumble/Extensions/SnapshotExtensions.cs index 44f06ac561d..1fdb28ecb1e 100644 --- a/src/CookieCrumble/src/CookieCrumble/Extensions/SnapshotExtensions.cs +++ b/src/CookieCrumble/src/CookieCrumble/Extensions/SnapshotExtensions.cs @@ -7,6 +7,12 @@ namespace CookieCrumble; public static class SnapshotExtensions { + public static void MatchInlineSnapshot( + this object? value, + string snapshot, + ISnapshotValueFormatter? formatter = null) + => Snapshot.Create().Add(value, formatter: formatter).MatchInline(snapshot); + public static void MatchSnapshot(this Snapshot value) => value.Match(); @@ -54,7 +60,11 @@ public static void MatchSnapshot(this Snapshot value) return snapshot; } - snapshot.Add(result.ToJson(), string.IsNullOrEmpty(name) ? "Result:" : $"{name} Result:"); + snapshot.Add( + result.ToJson(), + string.IsNullOrEmpty(name) + ? "Result:" + : $"{name} Result:"); snapshot.SetPostFix(TestEnvironment.TargetFramework); if (result.ContextData.TryGetValue("query", out var queryResult) && @@ -63,7 +73,9 @@ public static void MatchSnapshot(this Snapshot value) { snapshot.Add( queryString, - string.IsNullOrEmpty(name) ? "Query:" : $"{name} Query:", + string.IsNullOrEmpty(name) + ? "Query:" + : $"{name} Query:", SnapshotValueFormatters.PlainText); } @@ -71,7 +83,9 @@ public static void MatchSnapshot(this Snapshot value) { snapshot.Add( sql, - string.IsNullOrEmpty(name) ? "SQL:" : $"{name} SQL:", + string.IsNullOrEmpty(name) + ? "SQL:" + : $"{name} SQL:", SnapshotValueFormatters.PlainText); } @@ -79,7 +93,9 @@ public static void MatchSnapshot(this Snapshot value) { snapshot.Add( expression, - string.IsNullOrEmpty(name) ? "Expression:" : $"{name} Expression:", + string.IsNullOrEmpty(name) + ? "Expression:" + : $"{name} Expression:", SnapshotValueFormatters.PlainText); } diff --git a/src/HotChocolate/Core/src/Abstractions/WellKnownContextData.cs b/src/HotChocolate/Core/src/Abstractions/WellKnownContextData.cs index 7fd3a3da902..138e47a0325 100644 --- a/src/HotChocolate/Core/src/Abstractions/WellKnownContextData.cs +++ b/src/HotChocolate/Core/src/Abstractions/WellKnownContextData.cs @@ -1,5 +1,3 @@ -using System.Data; - namespace HotChocolate; /// diff --git a/src/HotChocolate/Core/src/Execution/DependencyInjection/RequestExecutorBuilderExtensions.DataLoader.cs b/src/HotChocolate/Core/src/Execution/DependencyInjection/RequestExecutorBuilderExtensions.DataLoader.cs index a1c09dd8ee8..8ed8feb6e47 100644 --- a/src/HotChocolate/Core/src/Execution/DependencyInjection/RequestExecutorBuilderExtensions.DataLoader.cs +++ b/src/HotChocolate/Core/src/Execution/DependencyInjection/RequestExecutorBuilderExtensions.DataLoader.cs @@ -3,7 +3,6 @@ using HotChocolate.Execution.Configuration; using Microsoft.Extensions.DependencyInjection.Extensions; -// ReSharper disable once CheckNamespace namespace Microsoft.Extensions.DependencyInjection; public static partial class RequestExecutorBuilderExtensions diff --git a/src/HotChocolate/Core/src/Execution/Processing/MiddlewareContext.Global.cs b/src/HotChocolate/Core/src/Execution/Processing/MiddlewareContext.Global.cs index 290e190c72c..2fa1ffd752e 100644 --- a/src/HotChocolate/Core/src/Execution/Processing/MiddlewareContext.Global.cs +++ b/src/HotChocolate/Core/src/Execution/Processing/MiddlewareContext.Global.cs @@ -27,6 +27,8 @@ public IServiceProvider Services set => _services = value ?? throw new ArgumentNullException(nameof(value)); } + public IServiceProvider RequestServices => _operationContext.Services; + public ISchema Schema => _operationContext.Schema; public IOperation Operation => _operationContext.Operation; diff --git a/src/HotChocolate/Core/src/Fetching/DefaultDataLoaderRegistry.cs b/src/HotChocolate/Core/src/Fetching/DefaultDataLoaderRegistry.cs index 6964507a75d..93860ee8f9e 100644 --- a/src/HotChocolate/Core/src/Fetching/DefaultDataLoaderRegistry.cs +++ b/src/HotChocolate/Core/src/Fetching/DefaultDataLoaderRegistry.cs @@ -8,14 +8,14 @@ namespace HotChocolate.Fetching; -public class DefaultDataLoaderRegistry : IDataLoaderRegistry +public sealed class DefaultDataLoaderRegistry : IDataLoaderRegistry { private readonly ConcurrentDictionary _dataLoaders = new(); private bool _disposed; public T GetOrRegister(string key, Func createDataLoader) where T : IDataLoader { - if (_dataLoaders.GetOrAdd(key, s => createDataLoader()) is T dataLoader) + if (_dataLoaders.GetOrAdd(key, _ => createDataLoader()) is T dataLoader) { return dataLoader; } @@ -27,8 +27,8 @@ public class DefaultDataLoaderRegistry : IDataLoaderRegistry typeof(T).FullName)); } - public T GetOrRegister(Func createDataLoader) where T : IDataLoader => - GetOrRegister(typeof(T).FullName ?? typeof(T).Name, createDataLoader); + public T GetOrRegister(Func createDataLoader) where T : IDataLoader + => GetOrRegister(typeof(T).FullName ?? typeof(T).Name, createDataLoader); public void Dispose() { diff --git a/src/HotChocolate/Core/src/Fetching/Extensions/DataLoaderResolverContextExtensions.cs b/src/HotChocolate/Core/src/Fetching/Extensions/DataLoaderResolverContextExtensions.cs index 564f4b469b7..acabfa11e68 100644 --- a/src/HotChocolate/Core/src/Fetching/Extensions/DataLoaderResolverContextExtensions.cs +++ b/src/HotChocolate/Core/src/Fetching/Extensions/DataLoaderResolverContextExtensions.cs @@ -75,7 +75,7 @@ public static class DataLoaderResolverContextExtensions throw new ArgumentNullException(nameof(fetch)); } - var services = context.Services; + var services = context.RequestServices; var reg = services.GetRequiredService(); FetchBatchDataLoader Loader() => new( @@ -181,7 +181,7 @@ public static class DataLoaderResolverContextExtensions throw new ArgumentNullException(nameof(fetch)); } - var services = context.Services; + var services = context.RequestServices; var reg = services.GetRequiredService(); FetchGroupedDataLoader Loader() => new( @@ -272,7 +272,7 @@ public static class DataLoaderResolverContextExtensions throw new ArgumentNullException(nameof(fetch)); } - var services = context.Services; + var services = context.RequestServices; var reg = services.GetRequiredService(); FetchCacheDataLoader Loader() => new( @@ -360,7 +360,7 @@ public static T DataLoader(this IResolverContext context, string key) throw new ArgumentNullException(nameof(key)); } - var services = context.Services; + var services = context.RequestServices; var reg = services.GetRequiredService(); return reg.GetOrRegister(key, () => CreateDataLoader(services)); } @@ -374,7 +374,7 @@ public static T DataLoader(this IResolverContext context) throw new ArgumentNullException(nameof(context)); } - var services = context.Services; + var services = context.RequestServices; var reg = services.GetRequiredService(); return reg.GetOrRegister(() => CreateDataLoader(services)); } diff --git a/src/HotChocolate/Core/src/Types/Resolvers/Expressions/Parameters/ServiceHelper.cs b/src/HotChocolate/Core/src/Types/Resolvers/Expressions/Parameters/ServiceHelper.cs index 563e501852f..6a3f3b7a8fa 100644 --- a/src/HotChocolate/Core/src/Types/Resolvers/Expressions/Parameters/ServiceHelper.cs +++ b/src/HotChocolate/Core/src/Types/Resolvers/Expressions/Parameters/ServiceHelper.cs @@ -43,7 +43,8 @@ internal static class ServiceHelper FieldMiddlewareDefinition serviceMiddleware = new(next => async context => { - var objectPool = context.Services.GetRequiredService>(); + var services = context.RequestServices; + var objectPool = services.GetRequiredService>(); var service = objectPool.Get(); context.RegisterForCleanup(() => @@ -84,7 +85,8 @@ internal static class ServiceHelper middleware = new FieldMiddlewareDefinition( next => async context => { - var scope = context.Services.CreateScope(); + var service = context.RequestServices; + var scope = service.CreateScope(); context.RegisterForCleanup(() => { scope.Dispose(); @@ -119,7 +121,7 @@ internal static class ServiceHelper await next(context).ConfigureAwait(false); }, isRepeatable: true, - key: WellKnownMiddleware.PooledService); + key: WellKnownMiddleware.ResolverService); definition.MiddlewareDefinitions.Insert(index + 1, serviceMiddleware); } } diff --git a/src/HotChocolate/Core/src/Types/Resolvers/IResolverContext.cs b/src/HotChocolate/Core/src/Types/Resolvers/IResolverContext.cs index ac1a594df5d..13bc94c8da3 100644 --- a/src/HotChocolate/Core/src/Types/Resolvers/IResolverContext.cs +++ b/src/HotChocolate/Core/src/Types/Resolvers/IResolverContext.cs @@ -16,10 +16,18 @@ namespace HotChocolate.Resolvers; public interface IResolverContext : IPureResolverContext { /// - /// Gets the scoped request service provider. + /// Gets the resolver service provider. + /// By default the resolver service provider is scoped to the request, + /// but middleware can create a resolver scope. /// IServiceProvider Services { get; set; } + /// + /// Gets the request scoped service provider. + /// We preserve here the access to the original service provider of the request. + /// + IServiceProvider RequestServices { get; } + /// /// Gets the name that the field will have in the response map. /// diff --git a/src/HotChocolate/Core/src/Types/Types/DirectiveCollection.cs b/src/HotChocolate/Core/src/Types/Types/DirectiveCollection.cs index ae468441b13..62118098f74 100644 --- a/src/HotChocolate/Core/src/Types/Types/DirectiveCollection.cs +++ b/src/HotChocolate/Core/src/Types/Types/DirectiveCollection.cs @@ -68,11 +68,9 @@ private static IEnumerable FindDirectives(Directive[] directives, str while (Unsafe.IsAddressLessThan(ref start, ref end)) { - var directive = Unsafe.Add(ref start, 0); - - if (directive.Type.Name.EqualsOrdinal(directiveName)) + if (start.Type.Name.EqualsOrdinal(directiveName)) { - return directive; + return start; } // move pointer diff --git a/src/HotChocolate/Core/test/Types.CursorPagination.Tests/QueryableCursorPagingProviderTests.cs b/src/HotChocolate/Core/test/Types.CursorPagination.Tests/QueryableCursorPagingProviderTests.cs index 6204f139690..8b957b599ee 100644 --- a/src/HotChocolate/Core/test/Types.CursorPagination.Tests/QueryableCursorPagingProviderTests.cs +++ b/src/HotChocolate/Core/test/Types.CursorPagination.Tests/QueryableCursorPagingProviderTests.cs @@ -433,6 +433,8 @@ public IServiceProvider Services set => throw new NotImplementedException(); } + public IServiceProvider RequestServices => throw new NotImplementedException(); + public string ResponseName => throw new NotImplementedException(); public bool HasErrors => throw new NotImplementedException(); diff --git a/src/HotChocolate/Core/test/Types.Tests/Types/Attributes/UseServiceScopeAttributeTests.cs b/src/HotChocolate/Core/test/Types.Tests/Types/Attributes/UseServiceScopeAttributeTests.cs index 6e4b20e002b..199a96f84da 100644 --- a/src/HotChocolate/Core/test/Types.Tests/Types/Attributes/UseServiceScopeAttributeTests.cs +++ b/src/HotChocolate/Core/test/Types.Tests/Types/Attributes/UseServiceScopeAttributeTests.cs @@ -1,9 +1,12 @@ using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading; using System.Threading.Tasks; +using CookieCrumble; using HotChocolate.Execution; +using HotChocolate.Resolvers; using Microsoft.Extensions.DependencyInjection; -using Snapshooter.Xunit; -using Xunit; namespace HotChocolate.Types; @@ -13,8 +16,6 @@ public class UseServiceScopeAttributeTests public async Task UseServiceScope() { // arrange - Snapshot.FullName(); - // assert var result = await new ServiceCollection() .AddScoped() @@ -23,8 +24,7 @@ public async Task UseServiceScope() .ExecuteRequestAsync("{ a: scoped b: scoped }"); // assert - Assert.Null(result.ExpectQueryResult().Errors); - var queryResult = Assert.IsAssignableFrom(result); + var queryResult = result.ExpectQueryResult(); Assert.NotEqual(queryResult.Data!["a"], queryResult.Data!["b"]); } @@ -33,11 +33,47 @@ public void UseServiceScope_FieldDescriptor() => Assert.Throws( () => default(IObjectFieldDescriptor).UseServiceScope()); + [Fact] + public async Task UseServiceScope_With_DataLoader() + { + // arrange + using var cts = new CancellationTokenSource(5000); + + // assert + var result = await new ServiceCollection() + .AddScoped() + .AddGraphQL() + .AddQueryType() + .ExecuteRequestAsync("{ scopeWithDataLoader }", cancellationToken: cts.Token); + + // assert + result.MatchInlineSnapshot( + """ + { + "data": { + "scopeWithDataLoader": "abc" + } + } + """); + + cts.Cancel(); + } + public class Query { [UseServiceScope] public string GetScoped([Service] Service service) => service.Id; + + [UseServiceScope] + public Task ScopeWithDataLoader(IResolverContext context, CancellationToken ct) + { + var dataLoader = + context.BatchDataLoader( + (keys, _) => Task.FromResult>( + keys.ToDictionary(a => a, a => a))); + return dataLoader.LoadAsync("abc", ct); + } } public class Service diff --git a/src/HotChocolate/Data/src/Data/Projections/Extensions/ProjectionObjectFieldDescriptorExtensions.cs b/src/HotChocolate/Data/src/Data/Projections/Extensions/ProjectionObjectFieldDescriptorExtensions.cs index 3ae40d3f589..df46caf93c9 100644 --- a/src/HotChocolate/Data/src/Data/Projections/Extensions/ProjectionObjectFieldDescriptorExtensions.cs +++ b/src/HotChocolate/Data/src/Data/Projections/Extensions/ProjectionObjectFieldDescriptorExtensions.cs @@ -248,6 +248,8 @@ private sealed class MiddlewareContextProxy : IMiddlewareContext IReadOnlyDictionary IPureResolverContext.ScopedContextData => ScopedContextData; + public IServiceProvider RequestServices => _context.RequestServices; + public string ResponseName => _context.ResponseName; public bool HasErrors => _context.HasErrors; diff --git a/src/HotChocolate/Data/src/EntityFramework/DbContextParameterExpressionBuilder.cs b/src/HotChocolate/Data/src/EntityFramework/DbContextParameterExpressionBuilder.cs index e1364e8001d..0e2ff4c4496 100644 --- a/src/HotChocolate/Data/src/EntityFramework/DbContextParameterExpressionBuilder.cs +++ b/src/HotChocolate/Data/src/EntityFramework/DbContextParameterExpressionBuilder.cs @@ -57,7 +57,7 @@ public void ApplyConfiguration(ParameterInfo parameter, ObjectFieldDescriptor de _ => _ => throw new NotSupportedException(), key: ToList); var serviceMiddleware = - definition.MiddlewareDefinitions.Last(t => t.Key == PooledService); + definition.MiddlewareDefinitions.Last(t => t.Key == ResolverService); var index = definition.MiddlewareDefinitions.IndexOf(serviceMiddleware) + 1; definition.MiddlewareDefinitions.Insert(index, placeholderMiddleware); AddCompletionMiddleware(definition, placeholderMiddleware); diff --git a/src/HotChocolate/Data/src/EntityFramework/Extensions/EntityFrameworkObjectFieldDescriptorExtensions.cs b/src/HotChocolate/Data/src/EntityFramework/Extensions/EntityFrameworkObjectFieldDescriptorExtensions.cs index 054fae05685..3a9efa1fc13 100644 --- a/src/HotChocolate/Data/src/EntityFramework/Extensions/EntityFrameworkObjectFieldDescriptorExtensions.cs +++ b/src/HotChocolate/Data/src/EntityFramework/Extensions/EntityFrameworkObjectFieldDescriptorExtensions.cs @@ -27,12 +27,12 @@ public static class EntityFrameworkObjectFieldDescriptorExtensions new(next => async context => { #if NET6_0_OR_GREATER - await using TDbContext dbContext = await context.Services + await using var dbContext = await context.RequestServices .GetRequiredService>() .CreateDbContextAsync() .ConfigureAwait(false); #else - using TDbContext dbContext = context.Services + using TDbContext dbContext = context.RequestServices .GetRequiredService>() .CreateDbContext(); #endif @@ -69,7 +69,7 @@ public static class EntityFrameworkObjectFieldDescriptorExtensions FieldMiddlewareDefinition contextMiddleware = new(next => async context => { - var dbContext = await context.Services + var dbContext = await context.RequestServices .GetRequiredService>() .CreateDbContextAsync() .ConfigureAwait(false);