Skip to content

Commit

Permalink
Fixed where the stitching query rewriter skipped over variables. (#2600)
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelstaib committed Nov 18, 2020
1 parent 9d453e6 commit bec4398
Show file tree
Hide file tree
Showing 9 changed files with 247 additions and 61 deletions.
Expand Up @@ -34,18 +34,15 @@ public class Context

public INamedOutputType? TypeContext { get; set; }

public DirectiveType Directive { get; set; }
public IOutputField? OutputField { get; set; }

public IOutputField OutputField { get; set; }
public IInputField? InputField { get; set; }

public IInputField InputField { get; set; }

public IInputType InputType { get; set; }
public IInputType? InputType { get; set; }

public ImmutableHashSet<string> FragmentPath { get; set; }

public IDictionary<string, FragmentDefinitionNode> Fragments
{ get; }
public IDictionary<string, FragmentDefinitionNode> Fragments { get; }

public Context Clone()
{
Expand Down
Expand Up @@ -55,43 +55,53 @@ internal class MergeRequestHelper
IQueryResult mergedResult,
IEnumerable<BufferedRequest> requests)
{
var handledErrors = new HashSet<IError>();
BufferedRequest? current = null;
QueryResultBuilder? resultBuilder = null;

foreach (BufferedRequest request in requests)
try
{
if (current is not null && resultBuilder is not null)
{
current.Promise.SetResult(resultBuilder.Create());
}
var handledErrors = new HashSet<IError>();
BufferedRequest? current = null;
QueryResultBuilder? resultBuilder = null;

try
foreach (BufferedRequest request in requests)
{
current = request;
resultBuilder = ExtractResult(request.Aliases!, mergedResult, handledErrors);
if (current is not null && resultBuilder is not null)
{
current.Promise.SetResult(resultBuilder.Create());
}

try
{
current = request;
resultBuilder = ExtractResult(request.Aliases!, mergedResult, handledErrors);
}
catch (Exception ex)
{
current = null;
resultBuilder = null;
request.Promise.SetException(ex);
}
}
catch (Exception ex)

if (current is not null && resultBuilder is not null)
{
current = null;
resultBuilder = null;
request.Promise.SetException(ex);
if (mergedResult.Errors is not null &&
handledErrors.Count < mergedResult.Errors.Count)
{
foreach (IError error in mergedResult.Errors.Except(handledErrors))
{
resultBuilder.AddError(error);
}
}

handledErrors.Clear();
current.Promise.SetResult(resultBuilder.Create());
}
}

if (current is not null && resultBuilder is not null)
catch (Exception ex)
{
if (mergedResult.Errors is not null &&
handledErrors.Count < mergedResult.Errors.Count)
foreach (BufferedRequest request in requests)
{
foreach (IError error in mergedResult.Errors.Except(handledErrors))
{
resultBuilder.AddError(error);
}
request.Promise.TrySetException(ex);
}

handledErrors.Clear();
current.Promise.SetResult(resultBuilder.Create());
}
}

Expand Down Expand Up @@ -130,30 +140,75 @@ internal class MergeRequestHelper
}
}

// This method extracts the relevant data from a merged result for a specific result.
private static QueryResultBuilder ExtractResult(
IDictionary<string, string> aliases,
IQueryResult mergedResult,
ICollection<IError> handledErrors)
{
var result = QueryResultBuilder.New();

if (mergedResult.Data is not null)
// We first try to identify and copy data segments that belong to our specific result.
ExtractData(aliases, mergedResult, result);

// After extracting the data, we will try to find errors that can be associated with
// our specific request for which we are trying to branch out the result.
ExtractErrors(aliases, mergedResult, handledErrors, result);

// Last but not least we will copy all extensions and contextData over
// to the specific responses.
if (mergedResult.Extensions is not null)
{
var data = new ResultMap();
data.EnsureCapacity(aliases.Count);
var i = 0;
result.SetExtensions(mergedResult.Extensions);
}

if (mergedResult.ContextData is not null)
{
foreach (KeyValuePair<string, object?> item in mergedResult.ContextData)
{
result.SetContextData(item.Key, item.Value);
}
}

return result;
}

private static void ExtractData(
IDictionary<string, string> aliases,
IQueryResult mergedResult,
QueryResultBuilder result)
{
var data = new ResultMap();
data.EnsureCapacity(aliases.Count);
var i = 0;

if (mergedResult.Data is not null)
{
foreach (KeyValuePair<string, string> alias in aliases)
{
if (mergedResult.Data.TryGetValue(alias.Key, out object? o))
{
data.SetValue(i++, alias.Value, o);
}
}

result.SetData(data);
}
else
{
foreach (KeyValuePair<string, string> alias in aliases)
{
data.SetValue(i++, alias.Value, null);
}
}

result.SetData(data);
}

private static void ExtractErrors(
IDictionary<string, string> aliases,
IQueryResult mergedResult,
ICollection<IError> handledErrors,
QueryResultBuilder result)
{
if (mergedResult.Errors is not null)
{
foreach (IError error in mergedResult.Errors)
Expand All @@ -165,21 +220,6 @@ internal class MergeRequestHelper
}
}
}

if (mergedResult.Extensions is not null)
{
result.SetExtensions(mergedResult.Extensions);
}

if (mergedResult.ContextData is not null)
{
foreach (KeyValuePair<string, object?> item in mergedResult.ContextData)
{
result.SetContextData(item.Key, item.Value);
}
}

return result;
}

private static IError RewriteError(IError error, string responseName)
Expand Down
Expand Up @@ -102,6 +102,10 @@ protected override FieldNode RewriteField(FieldNode node, bool first)
(p, c) => RewriteMany(p, c, RewriteArgument),
current.WithArguments);

current = Rewrite(current, node.Directives, first,
(p, c) => RewriteMany(p, c, RewriteDirective),
current.WithDirectives);

if (node.SelectionSet != null)
{
current = Rewrite(current, node.SelectionSet, false,
Expand All @@ -125,6 +129,23 @@ protected override FieldNode RewriteField(FieldNode node, bool first)
false)
: base.RewriteFragmentDefinition(node, false);

protected override DirectiveNode RewriteDirective(
DirectiveNode node, bool first)
{
if (node.Arguments.Count == 0)
{
return node;
}

DirectiveNode current = node;

current = Rewrite(current, current.Arguments, first,
(p, c) => RewriteMany(p, c, RewriteArgument),
current.WithArguments);

return current;
}

protected override VariableNode RewriteVariable(
VariableNode node, bool first) =>
node.WithName(node.Name.CreateNewName(_requestPrefix));
Expand Down
Expand Up @@ -27,13 +27,13 @@ internal sealed class RemoteRequestExecutor
throw new ArgumentNullException(nameof(executor));
}

/// <iniheritdoc />
/// <inheritdoc />
public ISchema Schema => _executor.Schema;

/// <iniheritdoc />
/// <inheritdoc />
public IServiceProvider Services => _executor.Services;

/// <iniheritdoc />
/// <inheritdoc />
public Task<IExecutionResult> ExecuteAsync(
IQueryRequest request,
CancellationToken cancellationToken = default)
Expand Down
Expand Up @@ -48,7 +48,7 @@ public async Task AutoMerge_Schema()
IHttpClientFactory httpClientFactory = CreateDefaultRemoteSchemas(configurationName);

IDatabase database = _connection.GetDatabase();
for (int i = 0; i < 10; i++)
for (var i = 0; i < 10; i++)
{
if (await database.SetLengthAsync(configurationName.Value) == 4)
{
Expand Down Expand Up @@ -142,7 +142,7 @@ public async Task AutoMerge_Execute()
IHttpClientFactory httpClientFactory = CreateDefaultRemoteSchemas(configurationName);

IDatabase database = _connection.GetDatabase();
for (int i = 0; i < 10; i++)
for (var i = 0; i < 10; i++)
{
if (await database.SetLengthAsync(configurationName.Value) == 4)
{
Expand Down Expand Up @@ -187,7 +187,7 @@ public async Task AutoMerge_AddLocal_Field_Execute()
IHttpClientFactory httpClientFactory = CreateDefaultRemoteSchemas(configurationName);

IDatabase database = _connection.GetDatabase();
for (int i = 0; i < 10; i++)
for (var i = 0; i < 10; i++)
{
if (await database.SetLengthAsync(configurationName.Value) == 4)
{
Expand Down
Expand Up @@ -126,6 +126,50 @@ public async Task AutoMerge_AddLocal_Field_Execute()
result.ToJson().MatchSnapshot();
}

[Fact]
public async Task Directive_Variables_Are_Correctly_Rewritten()
{
// arrange
IHttpClientFactory httpClientFactory = CreateDefaultRemoteSchemas();

IRequestExecutor executor =
await new ServiceCollection()
.AddSingleton(httpClientFactory)
.AddGraphQL()
.AddQueryType(d => d.Name("Query").Field("local").Resolve("I am local."))
.AddRemoteSchema(_accounts)
.AddRemoteSchema(_inventory)
.AddRemoteSchema(_products)
.AddRemoteSchema(_reviews)
.BuildRequestExecutorAsync();

// act
IExecutionResult result = await executor.ExecuteAsync(
@"query ($if1: Boolean! $if2: Boolean! $if3: Boolean! $if4: Boolean!) {
me {
id
alias1: name @include(if: $if1)
alias2: reviews @include(if: $if2) {
alias3: body @include(if: $if3)
alias4: product @include(if: $if4) {
upc
}
}
}
local
}",
new Dictionary<string, object>
{
{ "if1", true },
{ "if2", true },
{ "if3", true },
{ "if4", true },
});

// assert
result.ToJson().MatchSnapshot();
}

public TestServer CreateAccountsService() =>
Context.ServerFactory.Create(
services => services
Expand Down
@@ -0,0 +1,23 @@
{
"data": {
"me": {
"id": 1,
"alias1": "Ada Lovelace",
"alias2": [
{
"alias3": "Love it!",
"alias4": {
"upc": 1
}
},
{
"alias3": "Too expensive.",
"alias4": {
"upc": 2
}
}
]
},
"local": "I am local."
}
}

0 comments on commit bec4398

Please sign in to comment.