Skip to content

Commit

Permalink
Fixed DataLoader as Resolver leak
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelstaib committed Jan 16, 2023
1 parent da255e0 commit ed9a4e8
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 65 deletions.
Expand Up @@ -133,7 +133,7 @@ public bool Consume(ISyntaxInfo syntaxInfo)
new DataLoaderDefaultsInfo(null, null, true);

var processed = new HashSet<string>(StringComparer.Ordinal);
var dataLoaders = new List<(string Interface, string Class)>();
var dataLoaders = new List<DataLoaderInfo>();
var sourceText = new StringBuilder();

sourceText.AppendLine("// <auto-generated/>");
Expand All @@ -158,7 +158,7 @@ public bool Consume(ISyntaxInfo syntaxInfo)
}

if (dataLoader.MethodSymbol.DeclaredAccessibility is not Accessibility.Public
or Accessibility.Internal or Accessibility.ProtectedAndInternal)
and not Accessibility.Internal and not Accessibility.ProtectedAndInternal)
{
context.ReportDiagnostic(
Diagnostic.Create(
Expand All @@ -169,15 +169,6 @@ public bool Consume(ISyntaxInfo syntaxInfo)
continue;
}

var dataLoaderAttribute = dataLoader.MethodSymbol.GetDataLoaderAttribute();

var dataLoaderName = GetDataLoaderName(
dataLoader.MethodSymbol.Name,
dataLoaderAttribute);

var isScoped = dataLoaderAttribute.IsScoped() ?? defaults.Scoped ?? false;
var isPublic = dataLoaderAttribute.IsPublic() ?? defaults.IsPublic ?? true;

var keyArg = dataLoader.MethodSymbol.Parameters[0];
var keyType = keyArg.Type;
var cancellationTokenIndex = -1;
Expand Down Expand Up @@ -209,25 +200,18 @@ public bool Consume(ISyntaxInfo syntaxInfo)
}

var valueType = ExtractValueType(dataLoader.MethodSymbol.ReturnType, kind);
var dataLoaderMethod = dataLoader.MethodSymbol;
var containingNamespace = dataLoaderMethod.ContainingNamespace.ToDisplayString();
var dataLoaderFullName = containingNamespace + "." + dataLoaderName;

if (processed.Add(dataLoaderFullName))
if (processed.Add(dataLoader.FullName))
{
dataLoaders.Add(
(containingNamespace + ".I" + dataLoaderName,
dataLoaderFullName));
dataLoaders.Add(dataLoader);

GenerateDataLoader(
dataLoaderName,
dataLoader,
defaults,
kind,
isPublic,
isScoped,
dataLoaderMethod,
keyType,
valueType,
dataLoaderMethod.Parameters.Length,
dataLoader.MethodSymbol.Parameters.Length,
cancellationTokenIndex,
serviceMap,
sourceText);
Expand Down Expand Up @@ -255,25 +239,26 @@ public bool Consume(ISyntaxInfo syntaxInfo)
}

private static void GenerateDataLoader(
string name,
DataLoaderInfo dataLoader,
DataLoaderDefaultsInfo defaults,
DataLoaderKind kind,
bool isPublic,
bool isScoped,
IMethodSymbol method,
ITypeSymbol keyType,
ITypeSymbol valueType,
int parameterCount,
int cancelIndex,
Dictionary<int, string> services,
StringBuilder sourceText)
{
var isScoped = dataLoader.IsScoped ?? defaults.Scoped ?? false;
var isPublic = dataLoader.IsPublic ?? defaults.IsPublic ?? true;

sourceText.AppendLine();
sourceText.Append("namespace ");
sourceText.AppendLine(method.ContainingNamespace.ToDisplayString());
sourceText.AppendLine(dataLoader.Namespace);
sourceText.AppendLine("{");

// first we generate a DataLoader interface ...
var interfaceName = "I" + name;
var interfaceName = dataLoader.InterfaceName;

if (isPublic)
{
Expand Down Expand Up @@ -320,7 +305,7 @@ public bool Consume(ISyntaxInfo syntaxInfo)
sourceText.Append(" internal sealed class ");
}

sourceText.Append(name);
sourceText.Append(dataLoader.Name);

if (kind is DataLoaderKind.Batch)
{
Expand Down Expand Up @@ -365,7 +350,7 @@ public bool Consume(ISyntaxInfo syntaxInfo)
sourceText
.Append(Indent)
.Append(Indent)
.AppendLine($"public {name}(");
.AppendLine($"public {dataLoader.Name}(");
sourceText
.Append(Indent)
.Append(Indent)
Expand Down Expand Up @@ -393,7 +378,7 @@ public bool Consume(ISyntaxInfo syntaxInfo)
sourceText
.Append(Indent)
.Append(Indent)
.AppendLine($"public {name}(");
.AppendLine($"public {dataLoader.Name}(");
sourceText
.Append(Indent)
.Append(Indent)
Expand Down Expand Up @@ -487,9 +472,9 @@ public bool Consume(ISyntaxInfo syntaxInfo)
}

sourceText.Append(" return await global::");
sourceText.Append(ToTypeName(method.ContainingType));
sourceText.Append(dataLoader.ContainingType);
sourceText.Append(".");
sourceText.Append(method.Name);
sourceText.Append(dataLoader.MethodName);
sourceText.Append("(");

for (var i = 0; i < parameterCount; i++)
Expand Down Expand Up @@ -529,7 +514,7 @@ public bool Consume(ISyntaxInfo syntaxInfo)

private static void GenerateDataLoaderRegistrations(
ModuleInfo module,
List<(string Interface, string Class)> dataLoaders,
List<DataLoaderInfo> dataLoaders,
StringBuilder sourceText)
{
sourceText.Append(Indent)
Expand Down Expand Up @@ -560,9 +545,9 @@ public bool Consume(ISyntaxInfo syntaxInfo)
.Append(Indent)
.Append("builder.AddDataLoader<")
.Append("global::")
.Append(dataLoader.Interface)
.Append(dataLoader.InterfaceFullName)
.Append(", global::")
.Append(dataLoader.Class)
.Append(dataLoader.FullName)
.AppendLine(">();");
}

Expand All @@ -576,31 +561,6 @@ public bool Consume(ISyntaxInfo syntaxInfo)
.AppendLine("}");
}

private static string GetDataLoaderName(string name, AttributeData attribute)
{
if (attribute.TryGetName(out var s))
{
return s;
}

if (name.StartsWith("Get"))
{
name = name.Substring(3);
}

if (name.EndsWith("Async"))
{
name = name.Substring(0, name.Length - 5);
}

if (name.EndsWith("DataLoader"))
{
return name;
}

return name + "DataLoader";
}

private void InspectDataLoaderParameters(
DataLoaderInfo dataLoader,
ref int cancellationTokenIndex,
Expand Down
Expand Up @@ -15,8 +15,40 @@ public sealed class DataLoaderInfo : ISyntaxInfo, IEquatable<DataLoaderInfo>
AttributeSymbol = attributeSymbol;
MethodSymbol = methodSymbol;
MethodSyntax = methodSyntax;

var attribute = methodSymbol.GetDataLoaderAttribute();

Name = GetDataLoaderName(methodSymbol.Name, attribute);
InterfaceName = $"I{Name}";
Namespace = methodSymbol.ContainingNamespace.ToDisplayString();
FullName = $"{Namespace}.{Name}";
InterfaceFullName = $"{Namespace}.{InterfaceName}";
IsScoped = attribute.IsScoped();
IsPublic = attribute.IsPublic();
MethodName = methodSymbol.Name;

var type = methodSymbol.ContainingType;
ContainingType = $"{type.ContainingNamespace}.{type.Name}";
}

public string Name { get; }

public string FullName { get; }

public string Namespace { get; }

public string InterfaceName { get; }

public string InterfaceFullName { get; }

public string ContainingType { get; }

public string MethodName { get; }

public bool? IsScoped { get; }

public bool? IsPublic { get; }

public AttributeSyntax AttributeSyntax { get; }

public IMethodSymbol AttributeSymbol { get; }
Expand Down Expand Up @@ -56,4 +88,29 @@ public override int GetHashCode()
return hashCode;
}
}

private static string GetDataLoaderName(string name, AttributeData attribute)
{
if (attribute.TryGetName(out var s))
{
return s;
}

if (name.StartsWith("Get"))
{
name = name.Substring(3);
}

if (name.EndsWith("Async"))
{
name = name.Substring(0, name.Length - 5);
}

if (name.EndsWith("DataLoader"))
{
return name;
}

return name + "DataLoader";
}
}
Expand Up @@ -106,7 +106,7 @@ private static bool IsAssemblyAttributeList(SyntaxNode node)
foreach (var syntaxGenerator in _generators)
{
// gather infos for current generator
for (var i = 0; i < all.Length; i++)
for (var i = all.Length - 1; i >= 0; i--)
{
var syntaxInfo = all[i];

Expand Down
Expand Up @@ -33,10 +33,12 @@ public class DefaultTypeInspector : Convention, ITypeInspector
private readonly TypeCache _typeCache = new();
private readonly Dictionary<MemberInfo, ExtendedMethodInfo> _methods = new();
private readonly ConcurrentDictionary<(Type, bool, bool), MemberInfo[]> _memberCache = new();
private readonly Type? _dataLoaderAttribute;

public DefaultTypeInspector(bool ignoreRequiredAttribute = false)
{
IgnoreRequiredAttribute = ignoreRequiredAttribute;
_dataLoaderAttribute = Type.GetType("HotChocolate.DataLoaderAttribute");
}

/// <summary>
Expand Down Expand Up @@ -723,6 +725,12 @@ private bool CanBeHandled(MemberInfo member, bool includeIgnored)
}
}

if (_dataLoaderAttribute is not null &&
method.IsDefined(_dataLoaderAttribute))
{
return false;
}

return true;
}

Expand Down
14 changes: 12 additions & 2 deletions src/HotChocolate/Core/test/Types.Analyzers.Tests/SomeQuery.cs
Expand Up @@ -28,6 +28,16 @@ public Task<string> WithDataLoader(IFoosByIdDataLoader foosById)
{
return foosById.LoadAsync("a");
}

// should be ignored on the schema
[DataLoader]
public static async Task<string> GetFoosById55(
string id,
SomeService someService,
CancellationToken cancellationToken)
{
return "abc";
}
}

[MutationType]
Expand All @@ -45,7 +55,7 @@ public static class SomeSubscription
public static class DataLoaderGen
{
[DataLoader]
public static async Task<IReadOnlyDictionary<string, string>> GetFoosById(
internal static async Task<IReadOnlyDictionary<string, string>> GetFoosById(
IReadOnlyList<string> ids,
SomeService someService,
CancellationToken cancellationToken)
Expand All @@ -68,7 +78,7 @@ public static class DataLoaderGen
SomeService someService,
CancellationToken cancellationToken)
{
return default;
return default!;
}
}

Expand Down

0 comments on commit ed9a4e8

Please sign in to comment.