Skip to content

Commit

Permalink
Reworked Stitching Variable Handling (#2533)
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelstaib committed Nov 4, 2020
1 parent 1de7cf8 commit 115f47c
Show file tree
Hide file tree
Showing 16 changed files with 53 additions and 133 deletions.
Expand Up @@ -2,7 +2,6 @@
using System.Linq;
using HotChocolate.Language;
using HotChocolate.Resolvers;
using HotChocolate.Stitching.Properties;
using HotChocolate.Types;
using HotChocolate.Utilities;
using static HotChocolate.Stitching.ThrowHelper;
Expand Down
Expand Up @@ -177,7 +177,7 @@ public VariableNode ToVariableNode()

public string ToVariableName()
{
return Scope.Value + "_" + Name.Value;
return "__" + Scope.Value + "_" + Name.Value;
}
}
}
Expand Up @@ -13,12 +13,11 @@ namespace HotChocolate.Stitching.Requests
internal class MergeRequestHelper
{
public static IEnumerable<(IQueryRequest, IEnumerable<BufferedRequest>)> MergeRequests(
IEnumerable<BufferedRequest> requests,
ISet<string> requestVariableNames)
IEnumerable<BufferedRequest> requests)
{
foreach (var group in requests.GroupBy(t => t.Operation.Operation))
{
var rewriter = new MergeRequestRewriter(requestVariableNames);
var rewriter = new MergeRequestRewriter();
var variableValues = new Dictionary<string, object?>();

var operationName = group
Expand All @@ -36,7 +35,7 @@ internal class MergeRequestHelper
BufferedRequest first = null!;
foreach (BufferedRequest request in group)
{
first = request;
first ??= request;
MergeRequest(request, rewriter, variableValues, $"__{i++}_");
}

Expand Down
Expand Up @@ -14,24 +14,13 @@ internal class MergeRequestRewriter : QuerySyntaxRewriter<bool>
new Dictionary<string, VariableDefinitionNode>();
private readonly Dictionary<string, FragmentDefinitionNode> _fragments =
new Dictionary<string, FragmentDefinitionNode>();
private readonly ISet<string> _globalVariableNames;

private Dictionary<string, string>? _aliases;
private NameString _requestPrefix;
private bool _rewriteFragments;
private OperationType? _operationType;
private NameNode? _operationName;

public MergeRequestRewriter(ISet<string> globalVariableNames)
{
_globalVariableNames = globalVariableNames ??
throw new ArgumentNullException(nameof(globalVariableNames));
}

private bool IsAutoGenerated => !_rewriteFragments;

public NameNode OperationName => _operationName ?? _defaultName;

public void SetOperationName(NameNode name) => _operationName = name;

public IDictionary<string, string> AddQuery(
Expand All @@ -50,8 +39,7 @@ public MergeRequestRewriter(ISet<string> globalVariableNames)
OperationDefinitionNode operation =
BufferedRequest.ResolveOperation(rewritten, request.Request.OperationName);

foreach (VariableDefinitionNode variable in
operation.VariableDefinitions)
foreach (VariableDefinitionNode variable in operation.VariableDefinitions)
{
if (!_variables.ContainsKey(variable.Variable.Name.Value))
{
Expand Down Expand Up @@ -94,14 +82,9 @@ public DocumentNode Merge()
}

protected override VariableDefinitionNode RewriteVariableDefinition(
VariableDefinitionNode node, bool context)
{
return IsAutoGenerated
&& _globalVariableNames.Contains(node.Variable.Name.Value)
? node
: node.WithVariable(node.Variable.WithName(
node.Variable.Name.CreateNewName(_requestPrefix)));
}
VariableDefinitionNode node, bool context) =>
node.WithVariable(node.Variable.WithName(
node.Variable.Name.CreateNewName(_requestPrefix)));

protected override FieldNode RewriteField(FieldNode node, bool first)
{
Expand Down Expand Up @@ -129,30 +112,21 @@ protected override FieldNode RewriteField(FieldNode node, bool first)
}

protected override FragmentSpreadNode RewriteFragmentSpread(
FragmentSpreadNode node, bool first)
{
return _rewriteFragments
FragmentSpreadNode node, bool first) =>
_rewriteFragments
? node.WithName(node.Name.CreateNewName(_requestPrefix))
: node;
}

protected override FragmentDefinitionNode RewriteFragmentDefinition(
FragmentDefinitionNode node, bool first)
{
return _rewriteFragments
FragmentDefinitionNode node, bool first) =>
_rewriteFragments
? base.RewriteFragmentDefinition(
node.WithName(node.Name.CreateNewName(_requestPrefix)),
false)
: base.RewriteFragmentDefinition(node, false);
}

protected override VariableNode RewriteVariable(
VariableNode node, bool first)
{
return IsAutoGenerated
&& _globalVariableNames.Contains(node.Name.Value)
? node
: node.WithName(node.Name.CreateNewName(_requestPrefix));
}
VariableNode node, bool first) =>
node.WithName(node.Name.CreateNewName(_requestPrefix));
}
}
@@ -1,37 +1,30 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using GreenDonut;
using HotChocolate.Execution;
using static HotChocolate.Stitching.WellKnownContextData;

namespace HotChocolate.Stitching.Requests
{
internal sealed class RemoteRequestExecutor
: IRemoteRequestExecutor
, IDisposable
{
private static readonly HashSet<string> _empty = new HashSet<string>();
private readonly SemaphoreSlim _semaphore = new SemaphoreSlim(1, 1);
private readonly List<BufferedRequest> _bufferedRequests = new List<BufferedRequest>();
private readonly IBatchScheduler _batchScheduler;
private readonly IRequestExecutor _executor;
private readonly IRequestContextAccessor _requestContextAccessor;
private bool _taskRegistered;

public RemoteRequestExecutor(
IBatchScheduler batchScheduler,
IRequestExecutor executor,
IRequestContextAccessor requestContextAccessor)
IRequestExecutor executor)
{
_batchScheduler = batchScheduler ??
throw new ArgumentNullException(nameof(batchScheduler));
_executor = executor ??
throw new ArgumentNullException(nameof(executor));
_requestContextAccessor = requestContextAccessor ??
throw new ArgumentNullException(nameof(requestContextAccessor));
}

/// <iniheritdoc />
Expand Down Expand Up @@ -122,7 +115,7 @@ private async ValueTask ExecuteRequestsInternal(CancellationToken cancellationTo
// we however have to group requests by operation type. This means we should
// end up with one or two requests (query and mutation).
foreach ((IQueryRequest Merged, IEnumerable<BufferedRequest> Requests) batch in
MergeRequestHelper.MergeRequests(_bufferedRequests, GetRequestVariableNames()))
MergeRequestHelper.MergeRequests(_bufferedRequests))
{
// now we take this merged request and run it against the executor.
IExecutionResult result = await _executor
Expand All @@ -148,32 +141,6 @@ private async ValueTask ExecuteRequestsInternal(CancellationToken cancellationTo
}
}

private ISet<string> GetRequestVariableNames()
{
IRequestContext context = _requestContextAccessor.RequestContext;

if (context.ContextData is ConcurrentDictionary<string, object?> optimized)
{
return (ISet<string>)optimized.GetOrAdd(RequestVarNames, key =>
{
if (context.Request.VariableValues is { Count: > 0 } variables)
{
return new HashSet<string>(variables.Keys);
}
return _empty;
})!;
}

if(!context.ContextData.TryGetValue(RequestVarNames, out object? value))
{
value = context.Request.VariableValues is { Count: > 0 } variables
? new HashSet<string>(variables.Keys)
: _empty;
}

return (ISet<string>)value!;
}

public void Dispose()
{
_semaphore.Dispose();
Expand Down
Expand Up @@ -33,8 +33,7 @@ public class StitchingContext : IStitchingContext
executor.Key,
new RemoteRequestExecutor(
batchScheduler,
executor.Value,
requestContextAccessor));
executor.Value));
}
}

Expand Down
Expand Up @@ -43,7 +43,7 @@ public void CreateVariableValue()

// assert
Assert.Equal("bar", Assert.IsType<StringValueNode>(value.DefaultValue).Value);
Assert.Equal("arguments_a", value.Name);
Assert.Equal("__arguments_a", value.Name);
Assert.Equal("String", Assert.IsType<NamedTypeNode>(value.Type).Name.Value);
Assert.Equal("baz", value.Value.Value);
}
Expand Down
Expand Up @@ -43,7 +43,7 @@ public void CreateVariableValue()

// assert
Assert.Null(value.DefaultValue);
Assert.Equal("contextData_a", value.Name);
Assert.Equal("__contextData_a", value.Name);
Assert.Equal("String", Assert.IsType<NamedTypeNode>(value.Type).Name.Value);
Assert.Equal("AbcDef", value.Value.Value);
}
Expand Down Expand Up @@ -79,7 +79,7 @@ public void ContextDataEntryDoesNotExist()

// assert
Assert.Null(value.DefaultValue);
Assert.Equal("contextData_a", value.Name);
Assert.Equal("__contextData_a", value.Name);
Assert.Equal("String", Assert.IsType<NamedTypeNode>(value.Type).Name.Value);
Assert.Equal(NullValueNode.Default, value.Value);
}
Expand Down
Expand Up @@ -45,7 +45,7 @@ public void CreateVariableValue()

// assert
Assert.Null(value.DefaultValue);
Assert.Equal("fields_a", value.Name);
Assert.Equal("__fields_a", value.Name);
Assert.IsType<NamedTypeNode>(value.Type);
Assert.Equal("baz", value.Value.Value);
}
Expand Down
Expand Up @@ -38,13 +38,12 @@ public void BuildRemoteQuery()
.SelectionSet.Selections
.OfType<FieldNode>().Single();


// act
DocumentNode newQuery = RemoteQueryBuilder.New()
.SetOperation(null, OperationType.Query)
.SetSelectionPath(path)
.SetRequestField(field)
.AddVariable("fields_bar", new NamedTypeNode(null, new NameNode("String")))
.AddVariable("__fields_bar", new NamedTypeNode(null, new NameNode("String")))
.Build("abc", new Dictionary<(NameString Type, NameString Schema), NameString>());

// assert
Expand Down Expand Up @@ -88,7 +87,7 @@ public void BuildRemoteQueryCanOverrideOperationName()
OperationType.Query)
.SetSelectionPath(path)
.SetRequestField(field)
.AddVariable("fields_bar", new NamedTypeNode(null, new NameNode("String")))
.AddVariable("__fields_bar", new NamedTypeNode(null, new NameNode("String")))
.Build("abc", new Dictionary<(NameString Type, NameString Schema), NameString>());

// assert
Expand Down
Expand Up @@ -43,9 +43,9 @@ public void CreateVariableValue()

// assert
Assert.Null(value.DefaultValue);
Assert.Equal("scopedContextData_a", value.Name);
Assert.Equal("__scopedContextData_a", value.Name);
Assert.Equal("String", Assert.IsType<NamedTypeNode>(value.Type).Name.Value);
Assert.Equal("AbcDef", value.Value.Value);
Assert.Equal("AbcDef", value.Value!.Value);
}

[Fact]
Expand All @@ -60,7 +60,8 @@ public void ContextDataEntryDoesNotExist()
c.Options.StrictValidation = false;
});

var contextData = ImmutableDictionary<string, object>.Empty;
ImmutableDictionary<string, object> contextData =
ImmutableDictionary<string, object>.Empty;

var context = new Mock<IResolverContext>(MockBehavior.Strict);
context.SetupGet(t => t.ScopedContextData).Returns(contextData);
Expand All @@ -79,7 +80,7 @@ public void ContextDataEntryDoesNotExist()

// assert
Assert.Null(value.DefaultValue);
Assert.Equal("scopedContextData_a", value.Name);
Assert.Equal("__scopedContextData_a", value.Name);
Assert.Equal("String", Assert.IsType<NamedTypeNode>(value.Type).Name.Value);
Assert.Equal(NullValueNode.Default, value.Value);
}
Expand All @@ -104,13 +105,13 @@ public void ContextIsNull()

// act
var resolver = new ScopedContextDataScopedVariableResolver();
Action a = () => resolver.Resolve(
null,
void Action() => resolver.Resolve(
null!,
scopedVariable,
schema.GetType<StringType>("String"));

// assert
Assert.Equal("context", Assert.Throws<ArgumentNullException>(a).ParamName);
Assert.Equal("context", Assert.Throws<ArgumentNullException>(Action).ParamName);
}

[Fact]
Expand All @@ -129,27 +130,19 @@ public void ScopedVariableIsNull()

// act
var resolver = new ScopedContextDataScopedVariableResolver();
Action a = () => resolver.Resolve(
void Action() => resolver.Resolve(
context.Object,
null,
null!,
schema.GetType<StringType>("String"));

// assert
Assert.Equal("variable", Assert.Throws<ArgumentNullException>(a).ParamName);
Assert.Equal("variable", Assert.Throws<ArgumentNullException>(Action).ParamName);
}

[Fact]
public void TargetTypeIsNull()
{
// arrange
var schema = Schema.Create(
"type Query { foo(a: String = \"bar\") : String }",
c =>
{
c.Use(next => context => default);
c.Options.StrictValidation = false;
});

var context = new Mock<IMiddlewareContext>();

var scopedVariable = new ScopedVariableNode(
Expand All @@ -159,13 +152,13 @@ public void TargetTypeIsNull()

// act
var resolver = new ScopedContextDataScopedVariableResolver();
Action a = () => resolver.Resolve(
void Action() => resolver.Resolve(
context.Object,
scopedVariable,
null);
null!);

// assert
Assert.Equal("targetType", Assert.Throws<ArgumentNullException>(a).ParamName);
Assert.Equal("targetType", Assert.Throws<ArgumentNullException>(Action).ParamName);
}

[Fact]
Expand All @@ -189,13 +182,13 @@ public void InvalidScope()

// act
var resolver = new ScopedContextDataScopedVariableResolver();
Action a = () => resolver.Resolve(
void Action() => resolver.Resolve(
context.Object,
scopedVariable,
schema.GetType<StringType>("String"));

// assert
Assert.Equal("variable", Assert.Throws<ArgumentException>(a).ParamName);
Assert.Equal("variable", Assert.Throws<ArgumentException>(Action).ParamName);
}
}
}

0 comments on commit 115f47c

Please sign in to comment.