Skip to content

Commit

Permalink
Allow for refetching on mutation payloads. (#2851)
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelstaib committed Jan 11, 2021
1 parent 5d9c2df commit 4e19dd5
Show file tree
Hide file tree
Showing 40 changed files with 614 additions and 113 deletions.
@@ -1,5 +1,6 @@
using HotChocolate;
using HotChocolate.Execution.Configuration;
using HotChocolate.Types.Relay;

namespace Microsoft.Extensions.DependencyInjection
{
Expand All @@ -11,12 +12,16 @@ public static partial class SchemaRequestExecutorBuilderExtensions
/// <param name="builder">
/// The <see cref="IRequestExecutorBuilder"/>.
/// </param>
/// <param name="options">
/// The relay schema options.
/// </param>
/// <returns>
/// An <see cref="IRequestExecutorBuilder"/> that can be used to configure a schema
/// and its execution.
/// </returns>
public static IRequestExecutorBuilder EnableRelaySupport(
this IRequestExecutorBuilder builder) =>
builder.ConfigureSchema(c => c.EnableRelaySupport());
this IRequestExecutorBuilder builder,
RelayOptions? options = null) =>
builder.ConfigureSchema(c => c.EnableRelaySupport(options));
}
}
Expand Up @@ -78,7 +78,7 @@ internal sealed class OperationExecutionMiddleware
if (operation.Definition.Operation == OperationType.Subscription)
{
context.Result = await _subscriptionExecutor
.ExecuteAsync(context)
.ExecuteAsync(context, () => GetQueryRootValue(context))
.ConfigureAwait(false);

await _next(context).ConfigureAwait(false);
Expand Down Expand Up @@ -137,39 +137,33 @@ internal sealed class OperationExecutionMiddleware

if (operation.Definition.Operation == OperationType.Query)
{
object? query = RootValueResolver.Resolve(
context,
context.Services,
operation.RootType,
ref _cachedQueryValue);
object? query = GetQueryRootValue(context);

operationContext.Initialize(
context,
context.Services,
batchDispatcher,
operation,
context.Variables!,
query,
context.Variables!);
() => query);

result = await _queryExecutor
.ExecuteAsync(operationContext)
.ConfigureAwait(false);
}
else if (operation.Definition.Operation == OperationType.Mutation)
{
object? mutation = RootValueResolver.Resolve(
context,
context.Services,
operation.RootType,
ref _cachedMutation);
object? mutation = GetMutationRootValue(context);

operationContext.Initialize(
context,
context.Services,
batchDispatcher,
operation,
context.Variables!,
mutation,
context.Variables!);
() => GetQueryRootValue(context));

result = await _mutationExecutor
.ExecuteAsync(operationContext)
Expand All @@ -179,6 +173,20 @@ internal sealed class OperationExecutionMiddleware
return result;
}

private object? GetQueryRootValue(IRequestContext context) =>
RootValueResolver.Resolve(
context,
context.Services,
context.Schema.QueryType,
ref _cachedQueryValue);

private object? GetMutationRootValue(IRequestContext context) =>
RootValueResolver.Resolve(
context,
context.Services,
context.Schema.MutationType!,
ref _cachedMutation);

private bool IsOperationAllowed(IRequestContext context)
{
if (context.Request.AllowedOperations is null or { Length: 0 })
Expand Down
Expand Up @@ -145,5 +145,7 @@ public DirectiveContext(IMiddlewareContext middlewareContext, IDirective directi
SelectionSetNode? selectionSet = null,
bool allowInternals = false) =>
_middlewareContext.GetSelections(typeContext, selectionSet, allowInternals);

public T GetQueryRoot<T>() => _middlewareContext.GetQueryRoot<T>();
}
}
Expand Up @@ -7,6 +7,9 @@

namespace HotChocolate.Execution.Processing
{
/// <summary>
/// The internal operation execution context.
/// </summary>
internal interface IOperationContext : IHasContextData
{
/// <summary>
Expand Down Expand Up @@ -63,11 +66,27 @@ internal interface IOperationContext : IHasContextData
/// <value></value>
ITypeConverter Converter { get; }

/// <summary>
/// The result helper which provides utilities to build up the result.
/// </summary>
IResultHelper Result { get; }

/// <summary>
/// The execution context proved the processing state.
/// </summary>
IExecutionContext Execution { get; }

// TODO : documentation -> remember this are the raw collected fields without visibility
/// <summary>
/// Get the fields for the specified selection set according to the execution plan.
/// The selection set will show all possibilities and needs to be pre-processed.
/// </summary>
/// <param name="selectionSet">
/// The selection set syntax for which we want to get the compiled selection set.
/// </param>
/// <param name="typeContext">
/// The type context.
/// </param>
/// <returns></returns>
ISelectionSet CollectFields(
SelectionSetNode selectionSet,
ObjectType typeContext);
Expand All @@ -79,5 +98,16 @@ internal interface IOperationContext : IHasContextData
/// Cleanup action.
/// </param>
void RegisterForCleanup(Action action);

/// <summary>
/// Get the query root instance.
/// </summary>
/// <typeparam name="T">
/// The type of the query root.
/// </typeparam>
/// <returns>
/// Returns the query root instance.
/// </returns>
T GetQueryRoot<T>();
}
}
Expand Up @@ -158,5 +158,8 @@ public object Service(Type service)

public void RegisterForCleanup(Action action) =>
_operationContext.RegisterForCleanup(action);

public T GetQueryRoot<T>() =>
_operationContext.GetQueryRoot<T>();
}
}
Expand Up @@ -15,6 +15,7 @@ internal sealed partial class OperationContext
private IPreparedOperation _operation = default!;
private IVariableValueCollection _variables = default!;
private IServiceProvider _services = default!;
private Func<object> _resolveQueryRootValue = default!;
private object? _rootValue;
private bool _isPooled = true;

Expand All @@ -33,8 +34,9 @@ internal sealed partial class OperationContext
IServiceProvider scopedServices,
IBatchDispatcher batchDispatcher,
IPreparedOperation operation,
IVariableValueCollection variables,
object? rootValue,
IVariableValueCollection variables)
Func<object> resolveQueryRootValue)
{
_requestContext = requestContext;
_executionContext.Initialize(
Expand All @@ -44,6 +46,7 @@ internal sealed partial class OperationContext
_variables = variables;
_services = scopedServices;
_rootValue = rootValue;
_resolveQueryRootValue = resolveQueryRootValue;
_isPooled = false;
}

Expand All @@ -61,6 +64,7 @@ public void Clean()
_variables = default!;
_services = default!;
_rootValue = null;
_resolveQueryRootValue = default!;
_isPooled = true;
}

Expand Down
26 changes: 26 additions & 0 deletions src/HotChocolate/Core/src/Execution/Processing/OperationContext.cs
@@ -1,4 +1,6 @@
using System;
using System.Security.Cryptography;
using HotChocolate.Execution.Properties;
using HotChocolate.Language;
using HotChocolate.Types;

Expand Down Expand Up @@ -78,5 +80,29 @@ public void RegisterForCleanup(Action action)
AssertNotPooled();
_cleanupActions.Add(action);
}

public T GetQueryRoot<T>()
{
AssertNotPooled();

object? query = _resolveQueryRootValue();

if (query is null &&
typeof(T) == typeof(object) &&
new object() is T dummy)
{
return dummy;
}

if (query is T casted)
{
return casted;
}

throw new InvalidCastException(
string.Format(
Resources.OperationContext_GetQueryRoot_InvalidCast,
typeof(T).FullName ?? typeof(T).Name));
}
}
}
Expand Up @@ -10,17 +10,28 @@ internal static class ResolverExecutionHelper
IOperationContext operationContext)
{
var proposedTaskCount = operationContext.Operation.ProposedTaskCount;
var tasks = new Task[proposedTaskCount];

for (var i = 0; i < proposedTaskCount; i++)
if (proposedTaskCount == 1)
{
tasks[i] = StartExecutionTaskAsync(
await StartExecutionTaskAsync(
operationContext.Execution,
HandleError,
operationContext.RequestAborted);
}
else
{
var tasks = new Task[proposedTaskCount];

await Task.WhenAll(tasks).ConfigureAwait(false);
for (var i = 0; i < proposedTaskCount; i++)
{
tasks[i] = StartExecutionTaskAsync(
operationContext.Execution,
HandleError,
operationContext.RequestAborted);
}

await Task.WhenAll(tasks).ConfigureAwait(false);
}

void HandleError(Exception exception)
{
Expand Down
Expand Up @@ -12,17 +12,17 @@ namespace HotChocolate.Execution.Processing
{
internal sealed partial class SubscriptionExecutor
{
private sealed class Subscription
: IAsyncDisposable
private sealed class Subscription : IAsyncDisposable
{
private readonly ObjectPool<OperationContext> _operationContextPool;
private readonly QueryExecutor _queryExecutor;
private readonly IDiagnosticEvents _diagnosticEvents;
private readonly IRequestContext _requestContext;
private readonly ObjectType _subscriptionType;
private readonly ISelectionSet _rootSelections;
private readonly Func<object> _resolveQueryRootValue;
private ISourceStream _sourceStream = default!;
private object? _cachedRootValue = null;
private object? _cachedRootValue;
private bool _disposed;

private Subscription(
Expand All @@ -31,13 +31,15 @@ private sealed class Subscription
IRequestContext requestContext,
ObjectType subscriptionType,
ISelectionSet rootSelections,
Func<object> resolveQueryRootValue,
IDiagnosticEvents diagnosticEvents)
{
_operationContextPool = operationContextPool;
_queryExecutor = queryExecutor;
_requestContext = requestContext;
_subscriptionType = subscriptionType;
_rootSelections = rootSelections;
_resolveQueryRootValue = resolveQueryRootValue;
_diagnosticEvents = diagnosticEvents;
}

Expand All @@ -47,15 +49,17 @@ private sealed class Subscription
IRequestContext requestContext,
ObjectType subscriptionType,
ISelectionSet rootSelections,
IDiagnosticEvents diagnosicEvents)
Func<object> resolveQueryRootValue,
IDiagnosticEvents diagnosticsEvents)
{
var subscription = new Subscription(
operationContextPool,
queryExecutor,
requestContext,
subscriptionType,
rootSelections,
diagnosicEvents);
resolveQueryRootValue,
diagnosticsEvents);

subscription._sourceStream = await subscription
.SubscribeAsync()
Expand Down Expand Up @@ -116,8 +120,9 @@ private async Task<IQueryResult> OnEvent(object payload)
eventServices,
dispatcher,
_requestContext.Operation!,
_requestContext.Variables!,
rootValue,
_requestContext.Variables!);
_resolveQueryRootValue);

return await _queryExecutor
.ExecuteAsync(operationContext, scopedContext)
Expand Down Expand Up @@ -147,15 +152,16 @@ private async ValueTask<ISourceStream> SubscribeAsync()

// next we need to initialize our operation context so that we have access to
// variables services and other things.
// The subscribe resolver will use a noop dispatcher and all DataLoader are
// The subscribe resolver will use a noop dispatcher and all DataLoader are
// dispatched immediately.
operationContext.Initialize(
_requestContext,
_requestContext.Services,
NoopBatchDispatcher.Default,
_requestContext.Operation!,
_requestContext.Variables!,
rootValue,
_requestContext.Variables!);
_resolveQueryRootValue);

// next we need a result map so that we can store the subscribe temporarily
// while executing the subscribe pipeline.
Expand Down
Expand Up @@ -23,7 +23,8 @@ internal sealed partial class SubscriptionExecutor
}

public async Task<IExecutionResult> ExecuteAsync(
IRequestContext requestContext)
IRequestContext requestContext,
Func<object> resolveQueryValue)
{
if (requestContext is null)
{
Expand Down Expand Up @@ -57,6 +58,7 @@ internal sealed partial class SubscriptionExecutor
requestContext,
requestContext.Operation.RootType,
selectionSet,
resolveQueryValue,
_diagnosticEvents)
.ConfigureAwait(false);

Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 4e19dd5

Please sign in to comment.