Skip to content

Commit

Permalink
Added Support for Subscription Arguments on the Stream Factory (#5691)
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelstaib committed Jan 16, 2023
1 parent 8848df5 commit f8b573e
Show file tree
Hide file tree
Showing 11 changed files with 732 additions and 523 deletions.
Expand Up @@ -238,12 +238,22 @@ internal sealed class DefaultResolverCompiler : IResolverCompiler

if (member is MethodInfo method)
{
var parameters = method.GetParameters();
var owner = CreateResolverOwner(_context, sourceType, resolverType);
var parameterExpr = CreateParameters(_context, parameters, _empty);
Expression subscribeResolver = Call(owner, method, parameterExpr);
subscribeResolver = EnsureSubscribeResult(subscribeResolver, method.ReturnType);
return Lambda<SubscribeResolverDelegate>(subscribeResolver, _context).Compile();
if (method.IsStatic)
{
var parameterExpr = CreateParameters(_context, method.GetParameters(), _empty);
Expression subscribeResolver = Call(method, parameterExpr);
subscribeResolver = EnsureSubscribeResult(subscribeResolver, method.ReturnType);
return Lambda<SubscribeResolverDelegate>(subscribeResolver, _context).Compile();
}
else
{
var parameters = method.GetParameters();
var owner = CreateResolverOwner(_context, sourceType, resolverType);
var parameterExpr = CreateParameters(_context, parameters, _empty);
Expression subscribeResolver = Call(owner, method, parameterExpr);
subscribeResolver = EnsureSubscribeResult(subscribeResolver, method.ReturnType);
return Lambda<SubscribeResolverDelegate>(subscribeResolver, _context).Compile();
}
}

throw new ArgumentException(
Expand Down
Expand Up @@ -62,19 +62,23 @@ public sealed class SubscribeAttribute : ObjectFieldDescriptorAttribute
}
else
{

descriptor.Extend().OnBeforeCreate(d =>
{
var subscribeResolver = member.DeclaringType?.GetMethod(With!, Public | Instance);
if (subscribeResolver is null)
descriptor.Extend().OnBeforeCreate(
d =>
{
throw SubscribeAttribute_SubscribeResolverNotFound(member, With);
}
d.SubscribeResolver = context.ResolverCompiler.CompileSubscribe(
subscribeResolver, d.SourceType!, d.ResolverType);
});
var subscribeResolver = member.DeclaringType?.GetMethod(
With!,
Public | NonPublic | Instance | Static);
if (subscribeResolver is null)
{
throw SubscribeAttribute_SubscribeResolverNotFound(member, With);
}
d.SubscribeResolver = context.ResolverCompiler.CompileSubscribe(
subscribeResolver,
d.SourceType!,
d.ResolverType);
});
}
}

Expand All @@ -88,7 +92,9 @@ private static string ResolveTopicString(MethodInfo method)
return method.Name;
}

private static void SubscribeFactory<TMessage>(ObjectFieldDefinition fieldDef, string topicString)
private static void SubscribeFactory<TMessage>(
ObjectFieldDefinition fieldDef,
string topicString)
{
var arg = false;

Expand Down Expand Up @@ -125,10 +131,10 @@ private static void SubscribeFactory<TMessage>(ObjectFieldDefinition fieldDef, s
var ct = ctx.RequestAborted;
var receiver = ctx.Service<ITopicEventReceiver>();
return await receiver.SubscribeAsync<TMessage>(
topicString,
null,
null,
ct)
topicString,
null,
null,
ct)
.ConfigureAwait(false);
};
}
Expand All @@ -154,10 +160,10 @@ private static void SubscribeFactory<TMessage>(ObjectFieldDefinition fieldDef, s
// last we subscribe with the topic string.
var receiver = ctx.Service<ITopicEventReceiver>();
return await receiver.SubscribeAsync<TMessage>(
topicString,
null,
null,
ct)
topicString,
null,
null,
ct)
.ConfigureAwait(false);
};
}
Expand Down
Expand Up @@ -85,6 +85,11 @@ public ObjectFieldDefinition()
/// </summary>
public Type? ResultType { get; set; }

/// <summary>
/// The member name that represents the event stream factory.
/// </summary>
public string? SubscribeWith { get; set; }

/// <summary>
/// The delegate that represents the resolver.
/// </summary>
Expand Down Expand Up @@ -315,6 +320,7 @@ internal void CopyTo(ObjectFieldDefinition target)
target.IsIntrospectionField = IsIntrospectionField;
target.IsParallelExecutable = IsParallelExecutable;
target.HasStreamResult = HasStreamResult;
target.SubscribeWith = SubscribeWith;
}

internal void MergeInto(ObjectFieldDefinition target)
Expand Down Expand Up @@ -396,6 +402,11 @@ internal void MergeInto(ObjectFieldDefinition target)
{
target.SubscribeResolver = SubscribeResolver;
}

if (SubscribeWith is not null)
{
target.SubscribeWith = SubscribeWith;
}
}

private static void CleanMiddlewareDefinitions<T>(
Expand Down
Expand Up @@ -286,7 +286,6 @@ protected internal void MergeInto(ObjectTypeDefinition target)
newField.SourceType = target.RuntimeType;

SetResolverMember(newField, targetField);

target.Fields.Add(newField);
}
else
Expand Down
Expand Up @@ -11,6 +11,7 @@ public class InterfaceFieldDescriptor
: OutputFieldDescriptorBase<InterfaceFieldDefinition>
, IInterfaceFieldDescriptor
{
private ParameterInfo[] _parameterInfos = Array.Empty<ParameterInfo>();
private bool _argumentsInitialized;

protected internal InterfaceFieldDescriptor(
Expand Down Expand Up @@ -48,7 +49,8 @@ public class InterfaceFieldDescriptor

if (member is MethodInfo m)
{
Parameters = m.GetParameters().ToDictionary(t => t.Name, StringComparer.Ordinal);
_parameterInfos = m.GetParameters();
Parameters = _parameterInfos.ToDictionary(t => t.Name, StringComparer.Ordinal);
}
}

Expand All @@ -75,6 +77,7 @@ private void CompleteArguments(InterfaceFieldDefinition definition)
Context,
definition.Arguments,
definition.Member,
_parameterInfos,
definition.GetParameterExpressionBuilders());
_argumentsInitialized = true;
}
Expand Down
Expand Up @@ -10,6 +10,7 @@
using HotChocolate.Types.Descriptors.Definitions;
using HotChocolate.Types.Helpers;
using HotChocolate.Utilities;
using static System.Reflection.BindingFlags;
using static HotChocolate.Execution.ExecutionStrategy;

#nullable enable
Expand Down Expand Up @@ -157,17 +158,58 @@ protected override void OnCreateDefinition(ObjectFieldDefinition definition)

private void CompleteArguments(ObjectFieldDefinition definition)
{
if (!_argumentsInitialized && Parameters.Count > 0)
if (!_argumentsInitialized)
{
Context.ResolverCompiler.ApplyConfiguration(
_parameterInfos,
this);

FieldDescriptorUtilities.DiscoverArguments(
Context,
definition.Arguments,
definition.Member,
definition.GetParameterExpressionBuilders());
if (definition.SubscribeWith is not null)
{
var ownerType = definition.ResolverType ?? definition.SourceType;

if (ownerType is not null)
{
var subscribeMember = ownerType.GetMember(
definition.SubscribeWith,
Public | NonPublic | Instance | Static)[0];

if (subscribeMember is MethodInfo subscribeMethod)
{
var subscribeParameters = subscribeMethod.GetParameters();
var parameterLength = _parameterInfos.Length + subscribeParameters.Length;
var parameters = new ParameterInfo[parameterLength];

_parameterInfos.CopyTo(parameters, 0);
subscribeParameters.CopyTo(parameters, _parameterInfos.Length);
_parameterInfos = parameters;

var parameterLookup = Parameters.ToDictionary(
t => t.Key,
t => t.Value,
StringComparer.Ordinal);
Parameters = parameterLookup;

foreach (var parameter in subscribeParameters)
{
if (!parameterLookup.ContainsKey(parameter.Name!))
{
parameterLookup.Add(parameter.Name!, parameter);
}
}
}
}
}

if (Parameters.Count > 0)
{
Context.ResolverCompiler.ApplyConfiguration(
_parameterInfos,
this);

FieldDescriptorUtilities.DiscoverArguments(
Context,
definition.Arguments,
definition.Member,
_parameterInfos,
definition.GetParameterExpressionBuilders());
}

_argumentsInitialized = true;
}
Expand Down
Expand Up @@ -15,7 +15,7 @@ namespace HotChocolate.Types.Descriptors;

public class ObjectTypeDescriptor
: DescriptorBase<ObjectTypeDefinition>
, IObjectTypeDescriptor
, IObjectTypeDescriptor
{
private readonly List<ObjectFieldDescriptor> _fields = new();

Expand Down Expand Up @@ -130,7 +130,10 @@ internal void InferFieldsFromFieldBindingType()
IDictionary<string, ObjectFieldDefinition> fields,
ISet<MemberInfo> handledMembers)
{
HashSet<string>? subscribeResolver = null;
var skip = false;
HashSet<string>? subscribeRes = null;
Dictionary<MemberInfo, string>? subscribeResLook = null;


if (Definition.Fields.IsImplicitBinding() &&
Definition.FieldBindingType is not null)
Expand All @@ -148,14 +151,20 @@ internal void InferFieldsFromFieldBindingType()

if (handledMembers.Add(member) &&
!fields.ContainsKey(name) &&
IncludeField(ref subscribeResolver, members, member))
IncludeField(ref skip, ref subscribeRes, ref subscribeResLook, members, member))
{
var descriptor = ObjectFieldDescriptor.New(
Context,
member,
Definition.RuntimeType,
type);

if (subscribeResLook is not null &&
subscribeResLook.TryGetValue(member, out var with))
{
descriptor.Definition.SubscribeWith = with;
}

if (isExtension && inspector.IsMemberIgnored(member))
{
descriptor.Ignore();
Expand All @@ -173,35 +182,38 @@ internal void InferFieldsFromFieldBindingType()
}

static bool IncludeField(
ref bool skip,
ref HashSet<string>? subscribeResolver,
ref Dictionary<MemberInfo, string>? subscribeResolverLookup,
ReadOnlySpan<MemberInfo> allMembers,
MemberInfo current)
{
if (subscribeResolver is null)
// if there is now with declared we can include all members.
if (skip)
{
subscribeResolver = new HashSet<string>();
return true;
}

if (subscribeResolver is null)
{
foreach (var member in allMembers)
{
HandlePossibleSubscribeMember(subscribeResolver, member);
if (member.IsDefined(typeof(SubscribeAttribute)) &&
member.GetCustomAttribute<SubscribeAttribute>() is { With: not null } a)
{
subscribeResolver ??= new HashSet<string>();
subscribeResolverLookup ??= new Dictionary<MemberInfo, string>();
subscribeResolver.Add(a.With);
subscribeResolverLookup.Add(member, a.With);
}
}

skip = subscribeResolver is null;
}

return !subscribeResolver.Contains(current.Name);
return !subscribeResolver?.Contains(current.Name) ?? true;
}

static void HandlePossibleSubscribeMember(
HashSet<string> subscribeResolver,
MemberInfo member)
{
if (member.IsDefined(typeof(SubscribeAttribute)))
{
if (member.GetCustomAttribute<SubscribeAttribute>() is { With: not null } attr)
{
subscribeResolver.Add(attr.With);
}
}
}
}

protected virtual void OnCompleteFields(
Expand Down
Expand Up @@ -99,14 +99,15 @@ public static class FieldDescriptorUtilities
IDescriptorContext context,
ICollection<ArgumentDefinition> arguments,
MemberInfo? member,
ParameterInfo[] parameters,
IReadOnlyList<IParameterExpressionBuilder>? parameterExpressionBuilders)
{
if (arguments is null)
{
throw new ArgumentNullException(nameof(arguments));
}

if (member is MethodInfo method)
if (member is MethodInfo)
{
var processedNames = TypeMemHelper.RentNameSet();

Expand All @@ -122,7 +123,7 @@ public static class FieldDescriptorUtilities

foreach (var parameter in
context.ResolverCompiler.GetArgumentParameters(
method.GetParameters(),
parameters,
parameterExpressionBuilders))
{
var argumentDefinition =
Expand Down

0 comments on commit f8b573e

Please sign in to comment.