Skip to content

Commit

Permalink
Fixed ScopedServices handling on resolvers (#5671)
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelstaib committed Jan 10, 2023
1 parent 669e84c commit f4886d2
Show file tree
Hide file tree
Showing 14 changed files with 97 additions and 34 deletions.
Expand Up @@ -7,6 +7,12 @@ namespace CookieCrumble;

public static class SnapshotExtensions
{
public static void MatchInlineSnapshot(
this object? value,
string snapshot,
ISnapshotValueFormatter? formatter = null)
=> Snapshot.Create().Add(value, formatter: formatter).MatchInline(snapshot);

public static void MatchSnapshot(this Snapshot value)
=> value.Match();

Expand Down Expand Up @@ -54,7 +60,11 @@ public static void MatchSnapshot(this Snapshot value)
return snapshot;
}

snapshot.Add(result.ToJson(), string.IsNullOrEmpty(name) ? "Result:" : $"{name} Result:");
snapshot.Add(
result.ToJson(),
string.IsNullOrEmpty(name)
? "Result:"
: $"{name} Result:");
snapshot.SetPostFix(TestEnvironment.TargetFramework);

if (result.ContextData.TryGetValue("query", out var queryResult) &&
Expand All @@ -63,23 +73,29 @@ public static void MatchSnapshot(this Snapshot value)
{
snapshot.Add(
queryString,
string.IsNullOrEmpty(name) ? "Query:" : $"{name} Query:",
string.IsNullOrEmpty(name)
? "Query:"
: $"{name} Query:",
SnapshotValueFormatters.PlainText);
}

if (result.ContextData.TryGetValue("sql", out var sql))
{
snapshot.Add(
sql,
string.IsNullOrEmpty(name) ? "SQL:" : $"{name} SQL:",
string.IsNullOrEmpty(name)
? "SQL:"
: $"{name} SQL:",
SnapshotValueFormatters.PlainText);
}

if (result.ContextData.TryGetValue("expression", out var expression))
{
snapshot.Add(
expression,
string.IsNullOrEmpty(name) ? "Expression:" : $"{name} Expression:",
string.IsNullOrEmpty(name)
? "Expression:"
: $"{name} Expression:",
SnapshotValueFormatters.PlainText);
}

Expand Down
@@ -1,5 +1,3 @@
using System.Data;

namespace HotChocolate;

/// <summary>
Expand Down
Expand Up @@ -3,7 +3,6 @@
using HotChocolate.Execution.Configuration;
using Microsoft.Extensions.DependencyInjection.Extensions;

// ReSharper disable once CheckNamespace
namespace Microsoft.Extensions.DependencyInjection;

public static partial class RequestExecutorBuilderExtensions
Expand Down
Expand Up @@ -27,6 +27,8 @@ public IServiceProvider Services
set => _services = value ?? throw new ArgumentNullException(nameof(value));
}

public IServiceProvider RequestServices => _operationContext.Services;

public ISchema Schema => _operationContext.Schema;

public IOperation Operation => _operationContext.Operation;
Expand Down
Expand Up @@ -8,14 +8,14 @@

namespace HotChocolate.Fetching;

public class DefaultDataLoaderRegistry : IDataLoaderRegistry
public sealed class DefaultDataLoaderRegistry : IDataLoaderRegistry
{
private readonly ConcurrentDictionary<string, IDataLoader> _dataLoaders = new();
private bool _disposed;

public T GetOrRegister<T>(string key, Func<T> createDataLoader) where T : IDataLoader
{
if (_dataLoaders.GetOrAdd(key, s => createDataLoader()) is T dataLoader)
if (_dataLoaders.GetOrAdd(key, _ => createDataLoader()) is T dataLoader)
{
return dataLoader;
}
Expand All @@ -27,8 +27,8 @@ public class DefaultDataLoaderRegistry : IDataLoaderRegistry
typeof(T).FullName));
}

public T GetOrRegister<T>(Func<T> createDataLoader) where T : IDataLoader =>
GetOrRegister(typeof(T).FullName ?? typeof(T).Name, createDataLoader);
public T GetOrRegister<T>(Func<T> createDataLoader) where T : IDataLoader
=> GetOrRegister(typeof(T).FullName ?? typeof(T).Name, createDataLoader);

public void Dispose()
{
Expand Down
Expand Up @@ -75,7 +75,7 @@ public static class DataLoaderResolverContextExtensions
throw new ArgumentNullException(nameof(fetch));
}

var services = context.Services;
var services = context.RequestServices;
var reg = services.GetRequiredService<IDataLoaderRegistry>();
FetchBatchDataLoader<TKey, TValue> Loader()
=> new(
Expand Down Expand Up @@ -181,7 +181,7 @@ public static class DataLoaderResolverContextExtensions
throw new ArgumentNullException(nameof(fetch));
}

var services = context.Services;
var services = context.RequestServices;
var reg = services.GetRequiredService<IDataLoaderRegistry>();
FetchGroupedDataLoader<TKey, TValue> Loader()
=> new(
Expand Down Expand Up @@ -272,7 +272,7 @@ public static class DataLoaderResolverContextExtensions
throw new ArgumentNullException(nameof(fetch));
}

var services = context.Services;
var services = context.RequestServices;
var reg = services.GetRequiredService<IDataLoaderRegistry>();
FetchCacheDataLoader<TKey, TValue> Loader()
=> new(
Expand Down Expand Up @@ -360,7 +360,7 @@ public static T DataLoader<T>(this IResolverContext context, string key)
throw new ArgumentNullException(nameof(key));
}

var services = context.Services;
var services = context.RequestServices;
var reg = services.GetRequiredService<IDataLoaderRegistry>();
return reg.GetOrRegister(key, () => CreateDataLoader<T>(services));
}
Expand All @@ -374,7 +374,7 @@ public static T DataLoader<T>(this IResolverContext context)
throw new ArgumentNullException(nameof(context));
}

var services = context.Services;
var services = context.RequestServices;
var reg = services.GetRequiredService<IDataLoaderRegistry>();
return reg.GetOrRegister(() => CreateDataLoader<T>(services));
}
Expand Down
Expand Up @@ -43,7 +43,8 @@ internal static class ServiceHelper
FieldMiddlewareDefinition serviceMiddleware =
new(next => async context =>
{
var objectPool = context.Services.GetRequiredService<ObjectPool<TService>>();
var services = context.RequestServices;
var objectPool = services.GetRequiredService<ObjectPool<TService>>();
var service = objectPool.Get();
context.RegisterForCleanup(() =>
Expand Down Expand Up @@ -84,7 +85,8 @@ internal static class ServiceHelper
middleware = new FieldMiddlewareDefinition(
next => async context =>
{
var scope = context.Services.CreateScope();
var service = context.RequestServices;
var scope = service.CreateScope();
context.RegisterForCleanup(() =>
{
scope.Dispose();
Expand Down Expand Up @@ -119,7 +121,7 @@ internal static class ServiceHelper
await next(context).ConfigureAwait(false);
},
isRepeatable: true,
key: WellKnownMiddleware.PooledService);
key: WellKnownMiddleware.ResolverService);
definition.MiddlewareDefinitions.Insert(index + 1, serviceMiddleware);
}
}
10 changes: 9 additions & 1 deletion src/HotChocolate/Core/src/Types/Resolvers/IResolverContext.cs
Expand Up @@ -16,10 +16,18 @@ namespace HotChocolate.Resolvers;
public interface IResolverContext : IPureResolverContext
{
/// <summary>
/// Gets the scoped request service provider.
/// Gets the resolver service provider.
/// By default the resolver service provider is scoped to the request,
/// but middleware can create a resolver scope.
/// </summary>
IServiceProvider Services { get; set; }

/// <summary>
/// Gets the request scoped service provider.
/// We preserve here the access to the original service provider of the request.
/// </summary>
IServiceProvider RequestServices { get; }

/// <summary>
/// Gets the name that the field will have in the response map.
/// </summary>
Expand Down
6 changes: 2 additions & 4 deletions src/HotChocolate/Core/src/Types/Types/DirectiveCollection.cs
Expand Up @@ -68,11 +68,9 @@ private static IEnumerable<Directive> FindDirectives(Directive[] directives, str

while (Unsafe.IsAddressLessThan(ref start, ref end))
{
var directive = Unsafe.Add(ref start, 0);

if (directive.Type.Name.EqualsOrdinal(directiveName))
if (start.Type.Name.EqualsOrdinal(directiveName))
{
return directive;
return start;
}

// move pointer
Expand Down
Expand Up @@ -433,6 +433,8 @@ public IServiceProvider Services
set => throw new NotImplementedException();
}

public IServiceProvider RequestServices => throw new NotImplementedException();

public string ResponseName => throw new NotImplementedException();

public bool HasErrors => throw new NotImplementedException();
Expand Down
@@ -1,9 +1,12 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using CookieCrumble;
using HotChocolate.Execution;
using HotChocolate.Resolvers;
using Microsoft.Extensions.DependencyInjection;
using Snapshooter.Xunit;
using Xunit;

namespace HotChocolate.Types;

Expand All @@ -13,8 +16,6 @@ public class UseServiceScopeAttributeTests
public async Task UseServiceScope()
{
// arrange
Snapshot.FullName();

// assert
var result = await new ServiceCollection()
.AddScoped<Service>()
Expand All @@ -23,8 +24,7 @@ public async Task UseServiceScope()
.ExecuteRequestAsync("{ a: scoped b: scoped }");

// assert
Assert.Null(result.ExpectQueryResult().Errors);
var queryResult = Assert.IsAssignableFrom<IQueryResult>(result);
var queryResult = result.ExpectQueryResult();
Assert.NotEqual(queryResult.Data!["a"], queryResult.Data!["b"]);
}

Expand All @@ -33,11 +33,47 @@ public void UseServiceScope_FieldDescriptor()
=> Assert.Throws<ArgumentNullException>(
() => default(IObjectFieldDescriptor).UseServiceScope());

[Fact]
public async Task UseServiceScope_With_DataLoader()
{
// arrange
using var cts = new CancellationTokenSource(5000);

// assert
var result = await new ServiceCollection()
.AddScoped<Service>()
.AddGraphQL()
.AddQueryType<Query>()
.ExecuteRequestAsync("{ scopeWithDataLoader }", cancellationToken: cts.Token);

// assert
result.MatchInlineSnapshot(
"""
{
"data": {
"scopeWithDataLoader": "abc"
}
}
""");

cts.Cancel();
}

public class Query
{
[UseServiceScope]
public string GetScoped([Service] Service service)
=> service.Id;

[UseServiceScope]
public Task<string> ScopeWithDataLoader(IResolverContext context, CancellationToken ct)
{
var dataLoader =
context.BatchDataLoader<string, string>(
(keys, _) => Task.FromResult<IReadOnlyDictionary<string, string>>(
keys.ToDictionary(a => a, a => a)));
return dataLoader.LoadAsync("abc", ct);
}
}

public class Service
Expand Down
Expand Up @@ -248,6 +248,8 @@ private sealed class MiddlewareContextProxy : IMiddlewareContext
IReadOnlyDictionary<string, object?> IPureResolverContext.ScopedContextData
=> ScopedContextData;

public IServiceProvider RequestServices => _context.RequestServices;

public string ResponseName => _context.ResponseName;

public bool HasErrors => _context.HasErrors;
Expand Down
Expand Up @@ -57,7 +57,7 @@ public void ApplyConfiguration(ParameterInfo parameter, ObjectFieldDescriptor de
_ => _ => throw new NotSupportedException(),
key: ToList);
var serviceMiddleware =
definition.MiddlewareDefinitions.Last(t => t.Key == PooledService);
definition.MiddlewareDefinitions.Last(t => t.Key == ResolverService);
var index = definition.MiddlewareDefinitions.IndexOf(serviceMiddleware) + 1;
definition.MiddlewareDefinitions.Insert(index, placeholderMiddleware);
AddCompletionMiddleware(definition, placeholderMiddleware);
Expand Down
Expand Up @@ -27,12 +27,12 @@ public static class EntityFrameworkObjectFieldDescriptorExtensions
new(next => async context =>
{
#if NET6_0_OR_GREATER
await using TDbContext dbContext = await context.Services
await using var dbContext = await context.RequestServices
.GetRequiredService<IDbContextFactory<TDbContext>>()
.CreateDbContextAsync()
.ConfigureAwait(false);
#else
using TDbContext dbContext = context.Services
using TDbContext dbContext = context.RequestServices
.GetRequiredService<IDbContextFactory<TDbContext>>()
.CreateDbContext();
#endif
Expand Down Expand Up @@ -69,7 +69,7 @@ public static class EntityFrameworkObjectFieldDescriptorExtensions
FieldMiddlewareDefinition contextMiddleware =
new(next => async context =>
{
var dbContext = await context.Services
var dbContext = await context.RequestServices
.GetRequiredService<IDbContextFactory<TDbContext>>()
.CreateDbContextAsync()
.ConfigureAwait(false);
Expand Down

0 comments on commit f4886d2

Please sign in to comment.