Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow for pattern matching with the IsSelectedAttribute. #7061

Merged
merged 5 commits into from
Apr 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -309,4 +309,9 @@ public static class WellKnownContextData
/// Type key to access the node id result formatter on the descriptor context.
/// </summary>
public const string NodeIdResultFormatter = "HotChocolate.Relay.NodeIdResultFormatter";

/// <summary>
/// Type key to access the pattern validation tasks.
/// </summary>
public const string PatternValidationTasks = "HotChocolate.Validation.PatternValidationTasks";
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
using System.Collections.Generic;
using System.Linq;
using HotChocolate.Resolvers;
using HotChocolate.Types;

namespace HotChocolate.Execution.Processing;

internal sealed class EmptySelectionCollection : ISelectionCollection
{
private static readonly ISelection[] _empty = Array.Empty<ISelection>();

public static EmptySelectionCollection Instance { get; } = new();

public int Count => 0;
Expand All @@ -19,6 +20,12 @@ internal sealed class EmptySelectionCollection : ISelectionCollection
public ISelectionCollection Select(string fieldName)
=> Instance;

public ISelectionCollection Select(ReadOnlySpan<string> fieldNames)
=> Instance;

public ISelectionCollection Select(INamedType typeContext)
=> Instance;

public bool IsSelected(string fieldName)
=> false;

Expand All @@ -30,10 +37,10 @@ public bool IsSelected(string fieldName1, string fieldName2, string fieldName3)

public bool IsSelected(ISet<string> fieldNames)
=> false;

public IEnumerator<ISelection> GetEnumerator()
=> _empty.AsEnumerable().GetEnumerator();

IEnumerator IEnumerable.GetEnumerator()
=> GetEnumerator();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,9 @@ internal partial class MiddlewareContext
return selectionSet.Selections;
}

public ISelectionCollection Select()
=> new SelectionCollection(Schema, Operation, [Selection], _operationContext.IncludeFlags);

public ISelectionCollection Select(string fieldName)
=> new SelectionCollection(
Schema,
Operation,
[Selection,],
_operationContext.IncludeFlags)
.Select(fieldName);
}
=> Select().Select(fieldName);
}
208 changes: 154 additions & 54 deletions src/HotChocolate/Core/src/Execution/Processing/SelectionCollection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ internal sealed class SelectionCollection(
private readonly ISchema _schema = schema ?? throw new ArgumentNullException(nameof(schema));
private readonly IOperation _operation = operation ?? throw new ArgumentNullException(nameof(operation));
private readonly ISelection[] _selections = selections ?? throw new ArgumentNullException(nameof(selections));

public int Count => _selections.Length;

public ISelection this[int index] => _selections[index];
Expand All @@ -40,6 +40,32 @@ public ISelectionCollection Select(string fieldName)
return new SelectionCollection(_schema, _operation, selections, includeFlags);
}

public ISelectionCollection Select(ReadOnlySpan<string> fieldNames)
{
if (!CollectSelections(fieldNames, out var buffer, out var size))
{
return new SelectionCollection(_schema, _operation, Array.Empty<ISelection>(), includeFlags);
}

var selections = new ISelection[size];
buffer.AsSpan().Slice(0, size).CopyTo(selections);
ArrayPool<ISelection>.Shared.Return(buffer);
return new SelectionCollection(_schema, _operation, selections, includeFlags);
}

public ISelectionCollection Select(INamedType typeContext)
{
if (!CollectSelections(typeContext, out var buffer, out var size))
{
return new SelectionCollection(_schema, _operation, Array.Empty<ISelection>(), includeFlags);
}

var selections = new ISelection[size];
buffer.AsSpan().Slice(0, size).CopyTo(selections);
ArrayPool<ISelection>.Shared.Return(buffer);
return new SelectionCollection(_schema, _operation, selections, includeFlags);
}

public bool IsSelected(string fieldName)
{
if (fieldName is null)
Expand All @@ -64,11 +90,11 @@ public bool IsSelected(string fieldName)
foreach (var possibleType in _schema.GetPossibleTypes(namedType))
{
if (IsChildSelected(
_operation,
includeFlags,
possibleType,
start,
fieldName))
_operation,
includeFlags,
possibleType,
start,
fieldName))
{
return true;
}
Expand All @@ -77,11 +103,11 @@ public bool IsSelected(string fieldName)
else
{
if (IsChildSelected(
_operation,
includeFlags,
(ObjectType)namedType,
start,
fieldName))
_operation,
includeFlags,
(ObjectType)namedType,
start,
fieldName))
{
return true;
}
Expand Down Expand Up @@ -149,12 +175,12 @@ public bool IsSelected(string fieldName1, string fieldName2)
foreach (var possibleType in _schema.GetPossibleTypes(namedType))
{
if (IsChildSelected(
_operation,
includeFlags,
possibleType,
start,
fieldName1,
fieldName2))
_operation,
includeFlags,
possibleType,
start,
fieldName1,
fieldName2))
{
return true;
}
Expand All @@ -163,12 +189,12 @@ public bool IsSelected(string fieldName1, string fieldName2)
else
{
if (IsChildSelected(
_operation,
includeFlags,
(ObjectType)namedType,
start,
fieldName1,
fieldName2))
_operation,
includeFlags,
(ObjectType)namedType,
start,
fieldName1,
fieldName2))
{
return true;
}
Expand Down Expand Up @@ -243,13 +269,13 @@ public bool IsSelected(string fieldName1, string fieldName2, string fieldName3)
foreach (var possibleType in _schema.GetPossibleTypes(namedType))
{
if (IsChildSelected(
_operation,
includeFlags,
possibleType,
start,
fieldName1,
fieldName2,
fieldName3))
_operation,
includeFlags,
possibleType,
start,
fieldName1,
fieldName2,
fieldName3))
{
return true;
}
Expand All @@ -258,13 +284,13 @@ public bool IsSelected(string fieldName1, string fieldName2, string fieldName3)
else
{
if (IsChildSelected(
_operation,
includeFlags,
(ObjectType)namedType,
start,
fieldName1,
fieldName2,
fieldName3))
_operation,
includeFlags,
(ObjectType)namedType,
start,
fieldName1,
fieldName2,
fieldName3))
{
return true;
}
Expand Down Expand Up @@ -377,7 +403,24 @@ public bool IsSelected(ISet<string> fieldNames)
}
}

private bool CollectSelections(string fieldName, out ISelection[] buffer, out int size)
private bool CollectSelections(
string fieldName,
out ISelection[] buffer,
out int size)
{
var fieldNames = ArrayPool<string>.Shared.Rent(1);
var fieldNamesSpan = fieldNames.AsSpan().Slice(0, 1);
fieldNamesSpan[0] = fieldName;

var result = CollectSelections(fieldNamesSpan, out buffer, out size);
ArrayPool<string>.Shared.Return(fieldNames);
return result;
}

private bool CollectSelections(
ReadOnlySpan<string> fieldNames,
out ISelection[] buffer,
out int size)
{
buffer = ArrayPool<ISelection>.Shared.Rent(4);
size = 0;
Expand All @@ -389,7 +432,7 @@ private bool CollectSelections(string fieldName, out ISelection[] buffer, out in
{
if (!start.Type.IsCompositeType())
{
continue;
goto NEXT;
}

var namedType = start.Type.NamedType();
Expand All @@ -399,17 +442,18 @@ private bool CollectSelections(string fieldName, out ISelection[] buffer, out in
foreach (var possibleType in _schema.GetPossibleTypes(namedType))
{
var selectionSet = _operation.GetSelectionSet(start, possibleType);
Collect(ref buffer, selectionSet, size, out var written);
CollectFields(fieldNames, includeFlags, ref buffer, selectionSet, size, out var written);
size += written;
}
}
else
{
var selectionSet = _operation.GetSelectionSet(start, (ObjectType)namedType);
Collect(ref buffer, selectionSet, size, out var written);
CollectFields(fieldNames, includeFlags, ref buffer, selectionSet, size, out var written);
size += written;
}

NEXT:
start = ref Unsafe.Add(ref start, 1)!;
}

Expand All @@ -420,28 +464,69 @@ private bool CollectSelections(string fieldName, out ISelection[] buffer, out in
}

return size > 0;
}

private bool CollectSelections(
INamedType typeContext,
out ISelection[] buffer,
out int size)
{
buffer = ArrayPool<ISelection>.Shared.Rent(_selections.Length);
size = 0;

ref var start = ref MemoryMarshal.GetReference(_selections.AsSpan());
ref var end = ref Unsafe.Add(ref start, _selections.Length);

void Collect(ref ISelection[] buffer, ISelectionSet selectionSet, int index, out int written)
while (Unsafe.IsAddressLessThan(ref start, ref end))
{
written = 0;
if (typeContext.IsAssignableFrom(start.Type.NamedType()))
{
buffer[size++] = start;
}

var operationIncludeFlags = includeFlags;
var selectionCount = selectionSet.Selections.Count;
ref var selectionRef = ref ((SelectionSet)selectionSet).GetSelectionsReference();
start = ref Unsafe.Add(ref start, 1)!;
}

EnsureCapacity(ref buffer, index, selectionCount);
if (size == 0)
{
ArrayPool<ISelection>.Shared.Return(buffer);
buffer = Array.Empty<ISelection>();
}

for (var i = 0; i < selectionCount; i++)
{
var childSelection = Unsafe.Add(ref selectionRef, i);
return size > 0;
}

private static void CollectFields(
ReadOnlySpan<string> fieldNames,
long includeFlags,
ref ISelection[] buffer,
ISelectionSet selectionSet,
int index,
out int written)
{
written = 0;

var operationIncludeFlags = includeFlags;
var selectionCount = selectionSet.Selections.Count;

ref var selectionRef = ref ((SelectionSet)selectionSet).GetSelectionsReference();
ref var end = ref Unsafe.Add(ref selectionRef, selectionCount);

EnsureCapacity(ref buffer, index, selectionCount);

if (childSelection.IsIncluded(operationIncludeFlags) &&
childSelection.Field.Name.EqualsOrdinal(fieldName))
while (Unsafe.IsAddressLessThan(ref selectionRef, ref end))
{
foreach (var fieldName in fieldNames)
{
if (selectionRef.IsIncluded(operationIncludeFlags) &&
selectionRef.Field.Name.EqualsOrdinal(fieldName))
{
buffer[index++] = childSelection;
buffer[index++] = selectionRef;
written++;
}
}

selectionRef = ref Unsafe.Add(ref selectionRef, 1)!;
}
}

Expand Down Expand Up @@ -470,4 +555,19 @@ public IEnumerator<ISelection> GetEnumerator()

IEnumerator IEnumerable.GetEnumerator()
=> GetEnumerator();
}

private sealed class Any : INamedType
{
public TypeKind Kind => default!;

public string Name => default!;

public string Description => default!;

public IReadOnlyDictionary<string, object?> ContextData => default!;

public bool IsAssignableFrom(INamedType type) => true;

public static readonly Any Instance = new Any();
}
}