From ed9a4e880297a8f83311d828efe8f21207dfe2a5 Mon Sep 17 00:00:00 2001 From: Michael Staib Date: Mon, 16 Jan 2023 12:58:02 +0100 Subject: [PATCH] Fixed DataLoader as Resolver leak --- .../Generators/DataLoaderGenerator.cs | 84 +++++-------------- .../Inspectors/DataLoaderInfo.cs | 57 +++++++++++++ .../Types.Analyzers/TypeModuleGenerator.cs | 2 +- .../Conventions/DefaultTypeInspector.cs | 8 ++ .../test/Types.Analyzers.Tests/SomeQuery.cs | 14 +++- 5 files changed, 100 insertions(+), 65 deletions(-) diff --git a/src/HotChocolate/Core/src/Types.Analyzers/Generators/DataLoaderGenerator.cs b/src/HotChocolate/Core/src/Types.Analyzers/Generators/DataLoaderGenerator.cs index 91dd664f9ad..c202428d503 100644 --- a/src/HotChocolate/Core/src/Types.Analyzers/Generators/DataLoaderGenerator.cs +++ b/src/HotChocolate/Core/src/Types.Analyzers/Generators/DataLoaderGenerator.cs @@ -133,7 +133,7 @@ public bool Consume(ISyntaxInfo syntaxInfo) new DataLoaderDefaultsInfo(null, null, true); var processed = new HashSet(StringComparer.Ordinal); - var dataLoaders = new List<(string Interface, string Class)>(); + var dataLoaders = new List(); var sourceText = new StringBuilder(); sourceText.AppendLine("// "); @@ -158,7 +158,7 @@ public bool Consume(ISyntaxInfo syntaxInfo) } if (dataLoader.MethodSymbol.DeclaredAccessibility is not Accessibility.Public - or Accessibility.Internal or Accessibility.ProtectedAndInternal) + and not Accessibility.Internal and not Accessibility.ProtectedAndInternal) { context.ReportDiagnostic( Diagnostic.Create( @@ -169,15 +169,6 @@ public bool Consume(ISyntaxInfo syntaxInfo) continue; } - var dataLoaderAttribute = dataLoader.MethodSymbol.GetDataLoaderAttribute(); - - var dataLoaderName = GetDataLoaderName( - dataLoader.MethodSymbol.Name, - dataLoaderAttribute); - - var isScoped = dataLoaderAttribute.IsScoped() ?? defaults.Scoped ?? false; - var isPublic = dataLoaderAttribute.IsPublic() ?? defaults.IsPublic ?? true; - var keyArg = dataLoader.MethodSymbol.Parameters[0]; var keyType = keyArg.Type; var cancellationTokenIndex = -1; @@ -209,25 +200,18 @@ public bool Consume(ISyntaxInfo syntaxInfo) } var valueType = ExtractValueType(dataLoader.MethodSymbol.ReturnType, kind); - var dataLoaderMethod = dataLoader.MethodSymbol; - var containingNamespace = dataLoaderMethod.ContainingNamespace.ToDisplayString(); - var dataLoaderFullName = containingNamespace + "." + dataLoaderName; - if (processed.Add(dataLoaderFullName)) + if (processed.Add(dataLoader.FullName)) { - dataLoaders.Add( - (containingNamespace + ".I" + dataLoaderName, - dataLoaderFullName)); + dataLoaders.Add(dataLoader); GenerateDataLoader( - dataLoaderName, + dataLoader, + defaults, kind, - isPublic, - isScoped, - dataLoaderMethod, keyType, valueType, - dataLoaderMethod.Parameters.Length, + dataLoader.MethodSymbol.Parameters.Length, cancellationTokenIndex, serviceMap, sourceText); @@ -255,11 +239,9 @@ public bool Consume(ISyntaxInfo syntaxInfo) } private static void GenerateDataLoader( - string name, + DataLoaderInfo dataLoader, + DataLoaderDefaultsInfo defaults, DataLoaderKind kind, - bool isPublic, - bool isScoped, - IMethodSymbol method, ITypeSymbol keyType, ITypeSymbol valueType, int parameterCount, @@ -267,13 +249,16 @@ public bool Consume(ISyntaxInfo syntaxInfo) Dictionary services, StringBuilder sourceText) { + var isScoped = dataLoader.IsScoped ?? defaults.Scoped ?? false; + var isPublic = dataLoader.IsPublic ?? defaults.IsPublic ?? true; + sourceText.AppendLine(); sourceText.Append("namespace "); - sourceText.AppendLine(method.ContainingNamespace.ToDisplayString()); + sourceText.AppendLine(dataLoader.Namespace); sourceText.AppendLine("{"); // first we generate a DataLoader interface ... - var interfaceName = "I" + name; + var interfaceName = dataLoader.InterfaceName; if (isPublic) { @@ -320,7 +305,7 @@ public bool Consume(ISyntaxInfo syntaxInfo) sourceText.Append(" internal sealed class "); } - sourceText.Append(name); + sourceText.Append(dataLoader.Name); if (kind is DataLoaderKind.Batch) { @@ -365,7 +350,7 @@ public bool Consume(ISyntaxInfo syntaxInfo) sourceText .Append(Indent) .Append(Indent) - .AppendLine($"public {name}("); + .AppendLine($"public {dataLoader.Name}("); sourceText .Append(Indent) .Append(Indent) @@ -393,7 +378,7 @@ public bool Consume(ISyntaxInfo syntaxInfo) sourceText .Append(Indent) .Append(Indent) - .AppendLine($"public {name}("); + .AppendLine($"public {dataLoader.Name}("); sourceText .Append(Indent) .Append(Indent) @@ -487,9 +472,9 @@ public bool Consume(ISyntaxInfo syntaxInfo) } sourceText.Append(" return await global::"); - sourceText.Append(ToTypeName(method.ContainingType)); + sourceText.Append(dataLoader.ContainingType); sourceText.Append("."); - sourceText.Append(method.Name); + sourceText.Append(dataLoader.MethodName); sourceText.Append("("); for (var i = 0; i < parameterCount; i++) @@ -529,7 +514,7 @@ public bool Consume(ISyntaxInfo syntaxInfo) private static void GenerateDataLoaderRegistrations( ModuleInfo module, - List<(string Interface, string Class)> dataLoaders, + List dataLoaders, StringBuilder sourceText) { sourceText.Append(Indent) @@ -560,9 +545,9 @@ public bool Consume(ISyntaxInfo syntaxInfo) .Append(Indent) .Append("builder.AddDataLoader<") .Append("global::") - .Append(dataLoader.Interface) + .Append(dataLoader.InterfaceFullName) .Append(", global::") - .Append(dataLoader.Class) + .Append(dataLoader.FullName) .AppendLine(">();"); } @@ -576,31 +561,6 @@ public bool Consume(ISyntaxInfo syntaxInfo) .AppendLine("}"); } - private static string GetDataLoaderName(string name, AttributeData attribute) - { - if (attribute.TryGetName(out var s)) - { - return s; - } - - if (name.StartsWith("Get")) - { - name = name.Substring(3); - } - - if (name.EndsWith("Async")) - { - name = name.Substring(0, name.Length - 5); - } - - if (name.EndsWith("DataLoader")) - { - return name; - } - - return name + "DataLoader"; - } - private void InspectDataLoaderParameters( DataLoaderInfo dataLoader, ref int cancellationTokenIndex, diff --git a/src/HotChocolate/Core/src/Types.Analyzers/Inspectors/DataLoaderInfo.cs b/src/HotChocolate/Core/src/Types.Analyzers/Inspectors/DataLoaderInfo.cs index c7539495232..196559b3303 100644 --- a/src/HotChocolate/Core/src/Types.Analyzers/Inspectors/DataLoaderInfo.cs +++ b/src/HotChocolate/Core/src/Types.Analyzers/Inspectors/DataLoaderInfo.cs @@ -15,8 +15,40 @@ public sealed class DataLoaderInfo : ISyntaxInfo, IEquatable AttributeSymbol = attributeSymbol; MethodSymbol = methodSymbol; MethodSyntax = methodSyntax; + + var attribute = methodSymbol.GetDataLoaderAttribute(); + + Name = GetDataLoaderName(methodSymbol.Name, attribute); + InterfaceName = $"I{Name}"; + Namespace = methodSymbol.ContainingNamespace.ToDisplayString(); + FullName = $"{Namespace}.{Name}"; + InterfaceFullName = $"{Namespace}.{InterfaceName}"; + IsScoped = attribute.IsScoped(); + IsPublic = attribute.IsPublic(); + MethodName = methodSymbol.Name; + + var type = methodSymbol.ContainingType; + ContainingType = $"{type.ContainingNamespace}.{type.Name}"; } + public string Name { get; } + + public string FullName { get; } + + public string Namespace { get; } + + public string InterfaceName { get; } + + public string InterfaceFullName { get; } + + public string ContainingType { get; } + + public string MethodName { get; } + + public bool? IsScoped { get; } + + public bool? IsPublic { get; } + public AttributeSyntax AttributeSyntax { get; } public IMethodSymbol AttributeSymbol { get; } @@ -56,4 +88,29 @@ public override int GetHashCode() return hashCode; } } + + private static string GetDataLoaderName(string name, AttributeData attribute) + { + if (attribute.TryGetName(out var s)) + { + return s; + } + + if (name.StartsWith("Get")) + { + name = name.Substring(3); + } + + if (name.EndsWith("Async")) + { + name = name.Substring(0, name.Length - 5); + } + + if (name.EndsWith("DataLoader")) + { + return name; + } + + return name + "DataLoader"; + } } diff --git a/src/HotChocolate/Core/src/Types.Analyzers/TypeModuleGenerator.cs b/src/HotChocolate/Core/src/Types.Analyzers/TypeModuleGenerator.cs index e456e09d2ba..b0190076220 100644 --- a/src/HotChocolate/Core/src/Types.Analyzers/TypeModuleGenerator.cs +++ b/src/HotChocolate/Core/src/Types.Analyzers/TypeModuleGenerator.cs @@ -106,7 +106,7 @@ private static bool IsAssemblyAttributeList(SyntaxNode node) foreach (var syntaxGenerator in _generators) { // gather infos for current generator - for (var i = 0; i < all.Length; i++) + for (var i = all.Length - 1; i >= 0; i--) { var syntaxInfo = all[i]; diff --git a/src/HotChocolate/Core/src/Types/Types/Descriptors/Conventions/DefaultTypeInspector.cs b/src/HotChocolate/Core/src/Types/Types/Descriptors/Conventions/DefaultTypeInspector.cs index 3f56d88ae73..f9bd6b209bd 100644 --- a/src/HotChocolate/Core/src/Types/Types/Descriptors/Conventions/DefaultTypeInspector.cs +++ b/src/HotChocolate/Core/src/Types/Types/Descriptors/Conventions/DefaultTypeInspector.cs @@ -33,10 +33,12 @@ public class DefaultTypeInspector : Convention, ITypeInspector private readonly TypeCache _typeCache = new(); private readonly Dictionary _methods = new(); private readonly ConcurrentDictionary<(Type, bool, bool), MemberInfo[]> _memberCache = new(); + private readonly Type? _dataLoaderAttribute; public DefaultTypeInspector(bool ignoreRequiredAttribute = false) { IgnoreRequiredAttribute = ignoreRequiredAttribute; + _dataLoaderAttribute = Type.GetType("HotChocolate.DataLoaderAttribute"); } /// @@ -723,6 +725,12 @@ private bool CanBeHandled(MemberInfo member, bool includeIgnored) } } + if (_dataLoaderAttribute is not null && + method.IsDefined(_dataLoaderAttribute)) + { + return false; + } + return true; } diff --git a/src/HotChocolate/Core/test/Types.Analyzers.Tests/SomeQuery.cs b/src/HotChocolate/Core/test/Types.Analyzers.Tests/SomeQuery.cs index f42dd266c95..99944d2722f 100644 --- a/src/HotChocolate/Core/test/Types.Analyzers.Tests/SomeQuery.cs +++ b/src/HotChocolate/Core/test/Types.Analyzers.Tests/SomeQuery.cs @@ -28,6 +28,16 @@ public Task WithDataLoader(IFoosByIdDataLoader foosById) { return foosById.LoadAsync("a"); } + + // should be ignored on the schema + [DataLoader] + public static async Task GetFoosById55( + string id, + SomeService someService, + CancellationToken cancellationToken) + { + return "abc"; + } } [MutationType] @@ -45,7 +55,7 @@ public static class SomeSubscription public static class DataLoaderGen { [DataLoader] - public static async Task> GetFoosById( + internal static async Task> GetFoosById( IReadOnlyList ids, SomeService someService, CancellationToken cancellationToken) @@ -68,7 +78,7 @@ public static class DataLoaderGen SomeService someService, CancellationToken cancellationToken) { - return default; + return default!; } }